author     Mark Hamstra <markhamstra@gmail.com>  2013-03-16 11:54:44 -0700
committer  Mark Hamstra <markhamstra@gmail.com>  2013-03-16 11:54:44 -0700
commit     38454c4aedcc4b454d3470ed853d4741ca920db2 (patch)
tree       af81b3b052a5563d5613cba9cfdba728e5cdc8d0 /core
parent     cd5b947cf64ce0c8abb4b4bf5f37550522eac8e1 (diff)
parent     c1e9cdc49f89222b366a14a20ffd937ca0fb9adc (diff)
Merge branch 'master' of https://github.com/mesos/spark into WithThing
Diffstat (limited to 'core')
core/pom.xml                                                      |  6
core/src/main/scala/spark/BlockStoreShuffleFetcher.scala          |  2
core/src/main/scala/spark/PairRDDFunctions.scala                  | 17
core/src/main/scala/spark/RDD.scala                               | 18
core/src/main/scala/spark/executor/TaskMetrics.scala              |  4
core/src/main/scala/spark/rdd/HadoopRDD.scala                     | 33
core/src/main/scala/spark/rdd/SubtractedRDD.scala                 | 68
core/src/main/scala/spark/scheduler/SparkListener.scala           |  4
core/src/main/scala/spark/scheduler/local/LocalScheduler.scala    | 10
core/src/main/scala/spark/serializer/Serializer.scala             | 32
core/src/main/scala/spark/storage/BlockFetchTracker.scala         |  2
core/src/main/scala/spark/storage/BlockManager.scala              |  6
core/src/main/scala/spark/storage/DelegateBlockFetchTracker.scala |  2
core/src/main/scala/spark/util/NextIterator.scala                 | 71
core/src/test/scala/spark/ShuffleSuite.scala                      | 28
core/src/test/scala/spark/scheduler/SparkListenerSuite.scala      | 86
core/src/test/scala/spark/util/NextIteratorSuite.scala            | 68
17 files changed, 353 insertions(+), 104 deletions(-)
diff --git a/core/pom.xml b/core/pom.xml
index 9d46d94c1c..fe9c803728 100644
--- a/core/pom.xml
+++ b/core/pom.xml
@@ -3,7 +3,7 @@
<modelVersion>4.0.0</modelVersion>
<parent>
<groupId>org.spark-project</groupId>
- <artifactId>parent</artifactId>
+ <artifactId>spark-parent</artifactId>
<version>0.7.1-SNAPSHOT</version>
<relativePath>../pom.xml</relativePath>
</parent>
@@ -87,6 +87,10 @@
<groupId>org.apache.mesos</groupId>
<artifactId>mesos</artifactId>
</dependency>
+ <dependency>
+ <groupId>log4j</groupId>
+ <artifactId>log4j</artifactId>
+ </dependency>
<dependency>
<groupId>org.scalatest</groupId>
diff --git a/core/src/main/scala/spark/BlockStoreShuffleFetcher.scala b/core/src/main/scala/spark/BlockStoreShuffleFetcher.scala
index 53b0389c3a..c27ed36406 100644
--- a/core/src/main/scala/spark/BlockStoreShuffleFetcher.scala
+++ b/core/src/main/scala/spark/BlockStoreShuffleFetcher.scala
@@ -55,7 +55,7 @@ private[spark] class BlockStoreShuffleFetcher extends ShuffleFetcher with Logging
val shuffleMetrics = new ShuffleReadMetrics
shuffleMetrics.shuffleReadMillis = itr.getNetMillis
shuffleMetrics.remoteFetchTime = itr.remoteFetchTime
- shuffleMetrics.remoteFetchWaitTime = itr.remoteFetchWaitTime
+ shuffleMetrics.fetchWaitTime = itr.fetchWaitTime
shuffleMetrics.remoteBytesRead = itr.remoteBytesRead
shuffleMetrics.totalBlocksFetched = itr.totalBlocks
shuffleMetrics.localBlocksFetched = itr.numLocalBlocks
diff --git a/core/src/main/scala/spark/PairRDDFunctions.scala b/core/src/main/scala/spark/PairRDDFunctions.scala
index e7408e4352..3d1b1ca268 100644
--- a/core/src/main/scala/spark/PairRDDFunctions.scala
+++ b/core/src/main/scala/spark/PairRDDFunctions.scala
@@ -441,6 +441,23 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](
}
/**
+ * Return an RDD with the pairs from `this` whose keys are not in `other`.
+ *
+ * Uses `this` RDD's partitioner and partition count, because even if `other` is huge,
+ * the resulting RDD will be no larger than `this`.
+ */
+ def subtractByKey[W: ClassManifest](other: RDD[(K, W)]): RDD[(K, V)] =
+ subtractByKey(other, self.partitioner.getOrElse(new HashPartitioner(self.partitions.size)))
+
+ /** Return an RDD with the pairs from `this` whose keys are not in `other`. */
+ def subtractByKey[W: ClassManifest](other: RDD[(K, W)], numPartitions: Int): RDD[(K, V)] =
+ subtractByKey(other, new HashPartitioner(numPartitions))
+
+ /** Return an RDD with the pairs from `this` whose keys are not in `other`. */
+ def subtractByKey[W: ClassManifest](other: RDD[(K, W)], p: Partitioner): RDD[(K, V)] =
+ new SubtractedRDD[K, V, W](self, other, p)
+
+ /**
* Return the list of values in the RDD for key `key`. This operation is done efficiently if the
* RDD has a known partitioner by only searching the partition that the key maps to.
*/
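
For reference, a minimal usage sketch of the new subtractByKey overloads (illustrative only, not part of this patch; assumes a live SparkContext named sc):

    val rdd1 = sc.parallelize(Seq((1, "a"), (1, "a"), (2, "b"), (3, "c")))
    val rdd2 = sc.parallelize(Seq((2, 20), (3, 30), (4, 40)))
    // Keeps every pair from rdd1 whose key is absent from rdd2,
    // including duplicate values under a kept key.
    val kept  = rdd1.subtractByKey(rdd2)                                // rdd1's partitioning
    val kept8 = rdd1.subtractByKey(rdd2, 8)                             // explicit partition count
    val keptP = rdd1.subtractByKey(rdd2, new spark.HashPartitioner(8))  // explicit partitioner
    assert(kept.collect().toList.sorted == List((1, "a"), (1, "a")))
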
diff --git a/core/src/main/scala/spark/RDD.scala b/core/src/main/scala/spark/RDD.scala
index 2ad11bc604..b6defdd418 100644
--- a/core/src/main/scala/spark/RDD.scala
+++ b/core/src/main/scala/spark/RDD.scala
@@ -484,7 +484,23 @@ abstract class RDD[T: ClassManifest](
/**
* Return an RDD with the elements from `this` that are not in `other`.
*/
- def subtract(other: RDD[T], p: Partitioner): RDD[T] = new SubtractedRDD[T](this, other, p)
+ def subtract(other: RDD[T], p: Partitioner): RDD[T] = {
+ if (partitioner == Some(p)) {
+ // Our partitioner knows how to handle T (which, since we have a partitioner, is
+ // really (K, V)) so make a new Partitioner that will de-tuple our fake tuples
+ val p2 = new Partitioner() {
+ override def numPartitions = p.numPartitions
+ override def getPartition(k: Any) = p.getPartition(k.asInstanceOf[(Any, _)]._1)
+ }
+ // Unfortunately, since we're making a new p2, we'll still get ShuffleDependencies,
+ // and the result of calling .keys will not have a partitioner set, even though
+ // the SubtractedRDD, thanks to p2's de-tupled partitioning, is already
+ // partitioned by the real keys (i.e. by p).
+ this.map(x => (x, null)).subtractByKey(other.map((_, null)), p2).keys
+ } else {
+ this.map(x => (x, null)).subtractByKey(other.map((_, null)), p).keys
+ }
+ }
/**
* Reduces the elements of this RDD using the specified commutative and associative binary operator.
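
To illustrate the caveat in the comment above, a hypothetical session (not patch code; assumes a live SparkContext named sc): subtract's result reports no partitioner even though its data is already partitioned by p via the de-tupling p2.

    val p  = new spark.HashPartitioner(4)
    val r1 = sc.parallelize(Seq((1, "a"), (2, "b"), (3, "c"))).partitionBy(p)
    val r2 = sc.parallelize(Seq((2, "b")))
    val s  = r1.subtract(r2, p)
    assert(s.collect().toSet == Set((1, "a"), (3, "c")))
    // The final .keys step drops the partitioner metadata:
    assert(s.partitioner == None)
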
diff --git a/core/src/main/scala/spark/executor/TaskMetrics.scala b/core/src/main/scala/spark/executor/TaskMetrics.scala
index b9c07830f5..93bbb6b458 100644
--- a/core/src/main/scala/spark/executor/TaskMetrics.scala
+++ b/core/src/main/scala/spark/executor/TaskMetrics.scala
@@ -54,9 +54,9 @@ class ShuffleReadMetrics extends Serializable {
var shuffleReadMillis: Long = _
/**
- * Total time that is spent blocked waiting for shuffle to fetch remote data
+ * Total time that is spent blocked waiting for shuffle to fetch data
*/
- var remoteFetchWaitTime: Long = _
+ var fetchWaitTime: Long = _
/**
* The total amount of time for all the shuffle fetches. This adds up time from overlapping
diff --git a/core/src/main/scala/spark/rdd/HadoopRDD.scala b/core/src/main/scala/spark/rdd/HadoopRDD.scala
index 78097502bc..a6322dc58d 100644
--- a/core/src/main/scala/spark/rdd/HadoopRDD.scala
+++ b/core/src/main/scala/spark/rdd/HadoopRDD.scala
@@ -16,6 +16,7 @@ import org.apache.hadoop.mapred.Reporter
import org.apache.hadoop.util.ReflectionUtils
import spark.{Dependency, Logging, Partition, RDD, SerializableWritable, SparkContext, TaskContext}
+import spark.util.NextIterator
/**
@@ -62,7 +63,7 @@ class HadoopRDD[K, V](
.asInstanceOf[InputFormat[K, V]]
}
- override def compute(theSplit: Partition, context: TaskContext) = new Iterator[(K, V)] {
+ override def compute(theSplit: Partition, context: TaskContext) = new NextIterator[(K, V)] {
val split = theSplit.asInstanceOf[HadoopPartition]
var reader: RecordReader[K, V] = null
@@ -71,38 +72,22 @@ class HadoopRDD[K, V](
reader = fmt.getRecordReader(split.inputSplit.value, conf, Reporter.NULL)
// Register an on-task-completion callback to close the input stream.
- context.addOnCompleteCallback{ () => close() }
+ context.addOnCompleteCallback{ () => closeIfNeeded() }
val key: K = reader.createKey()
val value: V = reader.createValue()
- var gotNext = false
- var finished = false
-
- override def hasNext: Boolean = {
- if (!gotNext) {
- try {
- finished = !reader.next(key, value)
- } catch {
- case eof: EOFException =>
- finished = true
- }
- gotNext = true
- }
- !finished
- }
- override def next: (K, V) = {
- if (!gotNext) {
+ override def getNext() = {
+ try {
finished = !reader.next(key, value)
+ } catch {
+ case eof: EOFException =>
+ finished = true
}
- if (finished) {
- throw new NoSuchElementException("End of stream")
- }
- gotNext = false
(key, value)
}
- private def close() {
+ override def close() {
try {
reader.close()
} catch {
diff --git a/core/src/main/scala/spark/rdd/SubtractedRDD.scala b/core/src/main/scala/spark/rdd/SubtractedRDD.scala
index 43ec90cac5..0a02561062 100644
--- a/core/src/main/scala/spark/rdd/SubtractedRDD.scala
+++ b/core/src/main/scala/spark/rdd/SubtractedRDD.scala
@@ -1,7 +1,8 @@
package spark.rdd
-import java.util.{HashSet => JHashSet}
+import java.util.{HashMap => JHashMap}
import scala.collection.JavaConversions._
+import scala.collection.mutable.ArrayBuffer
import spark.RDD
import spark.Partitioner
import spark.Dependency
@@ -27,10 +28,10 @@ import spark.OneToOneDependency
* you can use `rdd1`'s partitioner/partition size and not worry about running
* out of memory because of the size of `rdd2`.
*/
-private[spark] class SubtractedRDD[T: ClassManifest](
- @transient var rdd1: RDD[T],
- @transient var rdd2: RDD[T],
- part: Partitioner) extends RDD[T](rdd1.context, Nil) {
+private[spark] class SubtractedRDD[K: ClassManifest, V: ClassManifest, W: ClassManifest](
+ @transient var rdd1: RDD[(K, V)],
+ @transient var rdd2: RDD[(K, W)],
+ part: Partitioner) extends RDD[(K, V)](rdd1.context, Nil) {
override def getDependencies: Seq[Dependency[_]] = {
Seq(rdd1, rdd2).map { rdd =>
@@ -39,26 +40,7 @@ private[spark] class SubtractedRDD[T: ClassManifest](
new OneToOneDependency(rdd)
} else {
logInfo("Adding shuffle dependency with " + rdd)
- val mapSideCombinedRDD = rdd.mapPartitions(i => {
- val set = new JHashSet[T]()
- while (i.hasNext) {
- set.add(i.next)
- }
- set.iterator
- }, true)
- // ShuffleDependency requires a tuple (k, v), which it will partition by k.
- // We need this to partition to map to the same place as the k for
- // OneToOneDependency, which means:
- // - for already-tupled RDD[(A, B)], into getPartition(a)
- // - for non-tupled RDD[C], into getPartition(c)
- val part2 = new Partitioner() {
- def numPartitions = part.numPartitions
- def getPartition(key: Any) = key match {
- case (k, v) => part.getPartition(k)
- case k => part.getPartition(k)
- }
- }
- new ShuffleDependency(mapSideCombinedRDD.map((_, null)), part2)
+ new ShuffleDependency(rdd.asInstanceOf[RDD[(K, Any)]], part)
}
}
}
@@ -81,22 +63,32 @@ private[spark] class SubtractedRDD[T: ClassManifest](
override val partitioner = Some(part)
- override def compute(p: Partition, context: TaskContext): Iterator[T] = {
+ override def compute(p: Partition, context: TaskContext): Iterator[(K, V)] = {
val partition = p.asInstanceOf[CoGroupPartition]
- val set = new JHashSet[T]
- def integrate(dep: CoGroupSplitDep, op: T => Unit) = dep match {
+ val map = new JHashMap[K, ArrayBuffer[V]]
+ def getSeq(k: K): ArrayBuffer[V] = {
+ val seq = map.get(k)
+ if (seq != null) {
+ seq
+ } else {
+ val seq = new ArrayBuffer[V]()
+ map.put(k, seq)
+ seq
+ }
+ }
+ def integrate(dep: CoGroupSplitDep, op: ((K, V)) => Unit) = dep match {
case NarrowCoGroupSplitDep(rdd, _, itsSplit) =>
- for (k <- rdd.iterator(itsSplit, context))
- op(k.asInstanceOf[T])
+ for (t <- rdd.iterator(itsSplit, context))
+ op(t.asInstanceOf[(K, V)])
case ShuffleCoGroupSplitDep(shuffleId) =>
- for ((k, _) <- SparkEnv.get.shuffleFetcher.fetch(shuffleId, partition.index, context.taskMetrics))
- op(k.asInstanceOf[T])
+ for (t <- SparkEnv.get.shuffleFetcher.fetch(shuffleId, partition.index, context.taskMetrics))
+ op(t.asInstanceOf[(K, V)])
}
- // the first dep is rdd1; add all keys to the set
- integrate(partition.deps(0), set.add)
- // the second dep is rdd2; remove all of its keys from the set
- integrate(partition.deps(1), set.remove)
- set.iterator
+ // the first dep is rdd1; add all values to the map
+ integrate(partition.deps(0), t => getSeq(t._1) += t._2)
+ // the second dep is rdd2; remove all of its keys
+ integrate(partition.deps(1), t => map.remove(t._1))
+ map.iterator.map { t => t._2.iterator.map { (t._1, _) } }.flatten
}
override def clearDependencies() {
@@ -105,4 +97,4 @@ private[spark] class SubtractedRDD[T: ClassManifest](
rdd2 = null
}
-}
\ No newline at end of file
+}
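
The new compute logic, restated over plain Scala collections as a self-contained sketch (a mirror of the JHashMap algorithm for illustration, not patch code):

    import scala.collection.mutable.{ArrayBuffer, HashMap}

    def subtractByKeyLocal[K, V, W](left: Seq[(K, V)], right: Seq[(K, W)]): Iterator[(K, V)] = {
      val map = new HashMap[K, ArrayBuffer[V]]
      // First pass (rdd1): bucket every value from the left side under its key.
      for ((k, v) <- left) map.getOrElseUpdate(k, new ArrayBuffer[V]) += v
      // Second pass (rdd2): any key seen on the right side is dropped wholesale.
      for ((k, _) <- right) map.remove(k)
      map.iterator.flatMap { case (k, vs) => vs.iterator.map(v => (k, v)) }
    }
    // subtractByKeyLocal(Seq(1 -> "a", 1 -> "a", 2 -> "b"), Seq(2 -> 20))
    // yields (1, "a") twice and drops key 2 entirely.
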
diff --git a/core/src/main/scala/spark/scheduler/SparkListener.scala b/core/src/main/scala/spark/scheduler/SparkListener.scala
index 21185227ab..a65140b145 100644
--- a/core/src/main/scala/spark/scheduler/SparkListener.scala
+++ b/core/src/main/scala/spark/scheduler/SparkListener.scala
@@ -31,7 +31,7 @@ class StatsReportListener extends SparkListener with Logging {
showBytesDistribution("shuffle bytes written:",(_,metric) => metric.shuffleWriteMetrics.map{_.shuffleBytesWritten})
//fetch & io
- showMillisDistribution("fetch wait time:",(_, metric) => metric.shuffleReadMetrics.map{_.remoteFetchWaitTime})
+ showMillisDistribution("fetch wait time:",(_, metric) => metric.shuffleReadMetrics.map{_.fetchWaitTime})
showBytesDistribution("remote bytes read:", (_, metric) => metric.shuffleReadMetrics.map{_.remoteBytesRead})
showBytesDistribution("task result size:", (_, metric) => Some(metric.resultSize))
@@ -137,7 +137,7 @@ case class RuntimePercentage(executorPct: Double, fetchPct: Option[Double], othe
object RuntimePercentage {
def apply(totalTime: Long, metrics: TaskMetrics): RuntimePercentage = {
val denom = totalTime.toDouble
- val fetchTime = metrics.shuffleReadMetrics.map{_.remoteFetchWaitTime}
+ val fetchTime = metrics.shuffleReadMetrics.map{_.fetchWaitTime}
val fetch = fetchTime.map{_ / denom}
val exec = (metrics.executorRunTime - fetchTime.getOrElse(0l)) / denom
val other = 1.0 - (exec + fetch.getOrElse(0d))
diff --git a/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala b/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala
index a76253ea14..9e1bde3fbe 100644
--- a/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala
+++ b/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala
@@ -67,8 +67,10 @@ private[spark] class LocalScheduler(threads: Int, maxFailures: Int, sc: SparkCon
logInfo("Size of task " + idInJob + " is " + bytes.limit + " bytes")
val (taskFiles, taskJars, taskBytes) = Task.deserializeWithDependencies(bytes)
updateDependencies(taskFiles, taskJars) // Download any files added with addFile
+ val deserStart = System.currentTimeMillis()
val deserializedTask = ser.deserialize[Task[_]](
taskBytes, Thread.currentThread.getContextClassLoader)
+ val deserTime = System.currentTimeMillis() - deserStart
// Run it
val result: Any = deserializedTask.run(attemptId)
@@ -77,15 +79,19 @@ private[spark] class LocalScheduler(threads: Int, maxFailures: Int, sc: SparkCon
// executor does. This is useful to catch serialization errors early
// on in development (so when users move their local Spark programs
// to the cluster, they don't get surprised by serialization errors).
- val resultToReturn = ser.deserialize[Any](ser.serialize(result))
+ val serResult = ser.serialize(result)
+ deserializedTask.metrics.get.resultSize = serResult.limit()
+ val resultToReturn = ser.deserialize[Any](serResult)
val accumUpdates = ser.deserialize[collection.mutable.Map[Long, Any]](
ser.serialize(Accumulators.values))
logInfo("Finished " + task)
info.markSuccessful()
+ deserializedTask.metrics.get.executorRunTime = info.duration.toInt //close enough
+ deserializedTask.metrics.get.executorDeserializeTime = deserTime.toInt
// If the threadpool has not already been shutdown, notify DAGScheduler
if (!Thread.currentThread().isInterrupted)
- listener.taskEnded(task, Success, resultToReturn, accumUpdates, info, null)
+ listener.taskEnded(task, Success, resultToReturn, accumUpdates, info, deserializedTask.metrics.getOrElse(null))
} catch {
case t: Throwable => {
logError("Exception in task " + idInJob, t)
diff --git a/core/src/main/scala/spark/serializer/Serializer.scala b/core/src/main/scala/spark/serializer/Serializer.scala
index 50b086125a..aca86ab6f0 100644
--- a/core/src/main/scala/spark/serializer/Serializer.scala
+++ b/core/src/main/scala/spark/serializer/Serializer.scala
@@ -72,40 +72,18 @@ trait DeserializationStream {
* Read the elements of this stream through an iterator. This can only be called once, as
* reading each element will consume data from the input source.
*/
- def asIterator: Iterator[Any] = new Iterator[Any] {
- var gotNext = false
- var finished = false
- var nextValue: Any = null
-
- private def getNext() {
+ def asIterator: Iterator[Any] = new spark.util.NextIterator[Any] {
+ override protected def getNext() = {
try {
- nextValue = readObject[Any]()
+ readObject[Any]()
} catch {
case eof: EOFException =>
finished = true
}
- gotNext = true
}
- override def hasNext: Boolean = {
- if (!gotNext) {
- getNext()
- }
- if (finished) {
- close()
- }
- !finished
- }
-
- override def next(): Any = {
- if (!gotNext) {
- getNext()
- }
- if (finished) {
- throw new NoSuchElementException("End of stream")
- }
- gotNext = false
- nextValue
+ override protected def close() {
+ DeserializationStream.this.close()
}
}
}
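
A consumption sketch for the refactored asIterator (hypothetical stream named in, not patch code): exhausting the iterator now closes the backing DeserializationStream automatically, because NextIterator.hasNext calls closeIfNeeded() once finished is set.

    // assuming in: spark.serializer.DeserializationStream
    val it = in.asIterator
    while (it.hasNext) {
      println(it.next())
    }
    // the underlying stream was closed when hasNext first returned false
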
diff --git a/core/src/main/scala/spark/storage/BlockFetchTracker.scala b/core/src/main/scala/spark/storage/BlockFetchTracker.scala
index ababb04305..993aece1f7 100644
--- a/core/src/main/scala/spark/storage/BlockFetchTracker.scala
+++ b/core/src/main/scala/spark/storage/BlockFetchTracker.scala
@@ -5,6 +5,6 @@ private[spark] trait BlockFetchTracker {
def numLocalBlocks: Int
def numRemoteBlocks: Int
def remoteFetchTime : Long
- def remoteFetchWaitTime: Long
+ def fetchWaitTime: Long
def remoteBytesRead : Long
}
diff --git a/core/src/main/scala/spark/storage/BlockManager.scala b/core/src/main/scala/spark/storage/BlockManager.scala
index 3118d3d412..210061e972 100644
--- a/core/src/main/scala/spark/storage/BlockManager.scala
+++ b/core/src/main/scala/spark/storage/BlockManager.scala
@@ -903,7 +903,7 @@ class BlockFetcherIterator(
private var _remoteBytesRead = 0l
private var _remoteFetchTime = 0l
- private var _remoteFetchWaitTime = 0l
+ private var _fetchWaitTime = 0l
if (blocksByAddress == null) {
throw new IllegalArgumentException("BlocksByAddress is null")
@@ -1046,7 +1046,7 @@ class BlockFetcherIterator(
val startFetchWait = System.currentTimeMillis()
val result = results.take()
val stopFetchWait = System.currentTimeMillis()
- _remoteFetchWaitTime += (stopFetchWait - startFetchWait)
+ _fetchWaitTime += (stopFetchWait - startFetchWait)
bytesInFlight -= result.size
while (!fetchRequests.isEmpty &&
(bytesInFlight == 0 || bytesInFlight + fetchRequests.front.size <= maxBytesInFlight)) {
@@ -1061,7 +1061,7 @@ class BlockFetcherIterator(
def numRemoteBlocks = remoteBlockIds.size
def remoteFetchTime = _remoteFetchTime
- def remoteFetchWaitTime = _remoteFetchWaitTime
+ def fetchWaitTime = _fetchWaitTime
def remoteBytesRead = _remoteBytesRead
diff --git a/core/src/main/scala/spark/storage/DelegateBlockFetchTracker.scala b/core/src/main/scala/spark/storage/DelegateBlockFetchTracker.scala
index 5c491877ba..f6c28dce52 100644
--- a/core/src/main/scala/spark/storage/DelegateBlockFetchTracker.scala
+++ b/core/src/main/scala/spark/storage/DelegateBlockFetchTracker.scala
@@ -7,6 +7,6 @@ private[spark] trait DelegateBlockFetchTracker extends BlockFetchTracker {
def numLocalBlocks = delegate.numLocalBlocks
def numRemoteBlocks = delegate.numRemoteBlocks
def remoteFetchTime = delegate.remoteFetchTime
- def remoteFetchWaitTime = delegate.remoteFetchWaitTime
+ def fetchWaitTime = delegate.fetchWaitTime
def remoteBytesRead = delegate.remoteBytesRead
}
diff --git a/core/src/main/scala/spark/util/NextIterator.scala b/core/src/main/scala/spark/util/NextIterator.scala
new file mode 100644
index 0000000000..48b5018ddd
--- /dev/null
+++ b/core/src/main/scala/spark/util/NextIterator.scala
@@ -0,0 +1,71 @@
+package spark.util
+
+/** Provides a basic/boilerplate Iterator implementation. */
+private[spark] abstract class NextIterator[U] extends Iterator[U] {
+
+ private var gotNext = false
+ private var nextValue: U = _
+ private var closed = false
+ protected var finished = false
+
+ /**
+ * Method for subclasses to implement to provide the next element.
+ *
+ * If no next element is available, the subclass should set `finished`
+ * to `true` and may return any value (it will be ignored).
+ *
+ * This convention is required because `null` may be a valid value,
+ * and using `Option` seems like it might create unnecessary Some/None
+ * instances, given some iterators might be called in a tight loop.
+ *
+ * @return the next element of type U, or an arbitrary (ignored) value with `finished` set to `true` when done
+ */
+ protected def getNext(): U
+
+ /**
+ * Method for subclasses to implement when all elements have been successfully
+ * iterated, and the iteration is done.
+ *
+ * <b>Note:</b> `NextIterator` cannot guarantee that `close` will be
+ * called because it has no control over what happens when an exception
+ * happens in the user code that is calling hasNext/next.
+ *
+ * Ideally you should have another try/catch, as in HadoopRDD, that
+ * ensures any resources are closed should iteration fail.
+ */
+ protected def close()
+
+ /**
+ * Calls the subclass-defined close method, but only once.
+ *
+ * Usually calling `close` multiple times should be fine, but historically
+ * there have been issues with some InputFormats throwing exceptions.
+ */
+ def closeIfNeeded() {
+ if (!closed) {
+ close()
+ closed = true
+ }
+ }
+
+ override def hasNext: Boolean = {
+ if (!finished) {
+ if (!gotNext) {
+ nextValue = getNext()
+ if (finished) {
+ closeIfNeeded()
+ }
+ gotNext = true
+ }
+ }
+ !finished
+ }
+
+ override def next(): U = {
+ if (!hasNext) {
+ throw new NoSuchElementException("End of stream")
+ }
+ gotNext = false
+ nextValue
+ }
+}
\ No newline at end of file
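
As a usage sketch, a hypothetical reader-backed subclass (not part of this commit) showing the getNext/finished/close contract; it is placed under spark.util because NextIterator is private[spark]:

    package spark.util

    import java.io.BufferedReader

    // Streams lines until readLine() returns null, then flags `finished` so
    // that hasNext stops and close() runs exactly once via closeIfNeeded().
    class LineIterator(reader: BufferedReader) extends NextIterator[String] {
      override protected def getNext(): String = {
        val line = reader.readLine()
        if (line == null) {
          finished = true
        }
        line  // ignored by callers once finished is true, per the contract
      }

      override protected def close() {
        reader.close()
      }
    }
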
diff --git a/core/src/test/scala/spark/ShuffleSuite.scala b/core/src/test/scala/spark/ShuffleSuite.scala
index 8411291b2c..2b2a90defa 100644
--- a/core/src/test/scala/spark/ShuffleSuite.scala
+++ b/core/src/test/scala/spark/ShuffleSuite.scala
@@ -272,13 +272,39 @@ class ShuffleSuite extends FunSuite with ShouldMatchers with LocalSparkContext {
}
// partitionBy so we have a narrow dependency
val a = sc.parallelize(Array((1, "a"), (2, "b"), (3, "c"))).partitionBy(p)
- println(sc.runJob(a, (i: Iterator[(Int, String)]) => i.toList).toList)
// more partitions/no partitioner so a shuffle dependency
val b = sc.parallelize(Array((2, "b"), (3, "cc"), (4, "d")), 4)
val c = a.subtract(b)
assert(c.collect().toSet === Set((1, "a"), (3, "c")))
+ // Ideally we could keep the original partitioner...
+ assert(c.partitioner === None)
+ }
+
+ test("subtractByKey") {
+ sc = new SparkContext("local", "test")
+ val a = sc.parallelize(Array((1, "a"), (1, "a"), (2, "b"), (3, "c")), 2)
+ val b = sc.parallelize(Array((2, 20), (3, 30), (4, 40)), 4)
+ val c = a.subtractByKey(b)
+ assert(c.collect().toSet === Set((1, "a"), (1, "a")))
+ assert(c.partitions.size === a.partitions.size)
+ }
+
+ test("subtractByKey with narrow dependency") {
+ sc = new SparkContext("local", "test")
+ // use a deterministic partitioner
+ val p = new Partitioner() {
+ def numPartitions = 5
+ def getPartition(key: Any) = key.asInstanceOf[Int]
+ }
+ // partitionBy so we have a narrow dependency
+ val a = sc.parallelize(Array((1, "a"), (1, "a"), (2, "b"), (3, "c"))).partitionBy(p)
+ // more partitions/no partitioner so a shuffle dependency
+ val b = sc.parallelize(Array((2, "b"), (3, "cc"), (4, "d")), 4)
+ val c = a.subtractByKey(b)
+ assert(c.collect().toSet === Set((1, "a"), (1, "a")))
assert(c.partitioner.get === p)
}
+
}
object ShuffleSuite {
diff --git a/core/src/test/scala/spark/scheduler/SparkListenerSuite.scala b/core/src/test/scala/spark/scheduler/SparkListenerSuite.scala
new file mode 100644
index 0000000000..2f5af10e69
--- /dev/null
+++ b/core/src/test/scala/spark/scheduler/SparkListenerSuite.scala
@@ -0,0 +1,86 @@
+package spark.scheduler
+
+import org.scalatest.FunSuite
+import spark.{SparkContext, LocalSparkContext}
+import scala.collection.mutable
+import org.scalatest.matchers.ShouldMatchers
+import spark.SparkContext._
+
+/**
+ *
+ */
+
+class SparkListenerSuite extends FunSuite with LocalSparkContext with ShouldMatchers {
+
+ test("local metrics") {
+ sc = new SparkContext("local[4]", "test")
+ val listener = new SaveStageInfo
+ sc.addSparkListener(listener)
+ sc.addSparkListener(new StatsReportListener)
+ //just to make sure some of the tasks take a noticeable amount of time
+ val w = {i:Int =>
+ if (i == 0)
+ Thread.sleep(100)
+ i
+ }
+
+ val d = sc.parallelize(1 to 1e4.toInt, 64).map{i => w(i)}
+ d.count
+ listener.stageInfos.size should be (1)
+
+ val d2 = d.map{i => w(i) -> i * 2}.setName("shuffle input 1")
+
+ val d3 = d.map{i => w(i) -> (0 to (i % 5))}.setName("shuffle input 2")
+
+ val d4 = d2.cogroup(d3, 64).map{case(k,(v1,v2)) => w(k) -> (v1.size, v2.size)}
+ d4.setName("A Cogroup")
+
+ d4.collectAsMap
+
+ listener.stageInfos.size should be (4)
+ listener.stageInfos.foreach {stageInfo =>
+ //small test, so some tasks might take less than 1 millisecond, but average should be greater than 1 ms
+ checkNonZeroAvg(stageInfo.taskInfos.map{_._1.duration}, stageInfo + " duration")
+ checkNonZeroAvg(stageInfo.taskInfos.map{_._2.executorRunTime.toLong}, stageInfo + " executorRunTime")
+ checkNonZeroAvg(stageInfo.taskInfos.map{_._2.executorDeserializeTime.toLong}, stageInfo + " executorDeserializeTime")
+ if (stageInfo.stage.rdd.name == d4.name) {
+ checkNonZeroAvg(stageInfo.taskInfos.map{_._2.shuffleReadMetrics.get.fetchWaitTime}, stageInfo + " fetchWaitTime")
+ }
+
+ stageInfo.taskInfos.foreach{case (taskInfo, taskMetrics) =>
+ taskMetrics.resultSize should be > (0l)
+ if (isStage(stageInfo, Set(d2.name, d3.name), Set(d4.name))) {
+ taskMetrics.shuffleWriteMetrics should be ('defined)
+ taskMetrics.shuffleWriteMetrics.get.shuffleBytesWritten should be > (0l)
+ }
+ if (stageInfo.stage.rdd.name == d4.name) {
+ taskMetrics.shuffleReadMetrics should be ('defined)
+ val sm = taskMetrics.shuffleReadMetrics.get
+ sm.totalBlocksFetched should be > (0)
+ sm.shuffleReadMillis should be > (0l)
+ sm.localBlocksFetched should be > (0)
+ sm.remoteBlocksFetched should be (0)
+ sm.remoteBytesRead should be (0l)
+ sm.remoteFetchTime should be (0l)
+ }
+ }
+ }
+ }
+
+ def checkNonZeroAvg(m: Traversable[Long], msg: String) {
+ assert(m.sum / m.size.toDouble > 0.0, msg)
+ }
+
+ def isStage(stageInfo: StageInfo, rddNames: Set[String], excludedNames: Set[String]) = {
+ val names = Set(stageInfo.stage.rdd.name) ++ stageInfo.stage.rdd.dependencies.map{_.rdd.name}
+ !names.intersect(rddNames).isEmpty && names.intersect(excludedNames).isEmpty
+ }
+
+ class SaveStageInfo extends SparkListener {
+ val stageInfos = mutable.Buffer[StageInfo]()
+ def onStageCompleted(stage: StageCompleted) {
+ stageInfos += stage.stageInfo
+ }
+ }
+
+}
diff --git a/core/src/test/scala/spark/util/NextIteratorSuite.scala b/core/src/test/scala/spark/util/NextIteratorSuite.scala
new file mode 100644
index 0000000000..ed5b36da73
--- /dev/null
+++ b/core/src/test/scala/spark/util/NextIteratorSuite.scala
@@ -0,0 +1,68 @@
+package spark.util
+
+import org.scalatest.FunSuite
+import org.scalatest.matchers.ShouldMatchers
+import scala.collection.mutable.Buffer
+import java.util.NoSuchElementException
+
+class NextIteratorSuite extends FunSuite with ShouldMatchers {
+ test("one iteration") {
+ val i = new StubIterator(Buffer(1))
+ i.hasNext should be === true
+ i.next should be === 1
+ i.hasNext should be === false
+ intercept[NoSuchElementException] { i.next() }
+ }
+
+ test("two iterations") {
+ val i = new StubIterator(Buffer(1, 2))
+ i.hasNext should be === true
+ i.next should be === 1
+ i.hasNext should be === true
+ i.next should be === 2
+ i.hasNext should be === false
+ intercept[NoSuchElementException] { i.next() }
+ }
+
+ test("empty iteration") {
+ val i = new StubIterator(Buffer())
+ i.hasNext should be === false
+ intercept[NoSuchElementException] { i.next() }
+ }
+
+ test("close is called once for empty iterations") {
+ val i = new StubIterator(Buffer())
+ i.hasNext should be === false
+ i.hasNext should be === false
+ i.closeCalled should be === 1
+ }
+
+ test("close is called once for non-empty iterations") {
+ val i = new StubIterator(Buffer(1, 2))
+ i.next should be === 1
+ i.next should be === 2
+ // close isn't called until we check for the next element
+ i.closeCalled should be === 0
+ i.hasNext should be === false
+ i.closeCalled should be === 1
+ i.hasNext should be === false
+ i.closeCalled should be === 1
+ }
+
+ class StubIterator(ints: Buffer[Int]) extends NextIterator[Int] {
+ var closeCalled = 0
+
+ override def getNext() = {
+ if (ints.size == 0) {
+ finished = true
+ 0
+ } else {
+ ints.remove(0)
+ }
+ }
+
+ override def close() {
+ closeCalled += 1
+ }
+ }
+}