-rwxr-xr-x  bin/compute-classpath.sh | 22
-rwxr-xr-x  bin/slaves.sh | 19
-rwxr-xr-x  bin/spark-daemon.sh | 21
-rwxr-xr-x  bin/spark-daemons.sh | 2
-rw-r--r--  core/src/main/scala/org/apache/spark/SparkContext.scala | 35
-rw-r--r--  core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala | 5
-rw-r--r--  core/src/main/scala/org/apache/spark/TaskContext.scala | 2
-rw-r--r--  core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala | 247
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala | 40
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala | 2
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala | 7
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerListener.scala | 44
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterScheduler.scala | 14
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterTaskSetManager.scala | 21
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/local/LocalScheduler.scala | 6
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/local/LocalTaskSetManager.scala | 11
-rw-r--r--  core/src/main/scala/org/apache/spark/storage/BlockId.scala | 9
-rw-r--r--  core/src/main/scala/org/apache/spark/storage/BlockManager.scala | 5
-rw-r--r--  core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala | 4
-rw-r--r--  core/src/test/scala/org/apache/spark/BroadcastSuite.scala | 52
-rw-r--r--  core/src/test/scala/org/apache/spark/JobCancellationSuite.scala | 32
-rw-r--r--  core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala | 2
-rw-r--r--  core/src/test/scala/org/apache/spark/scheduler/cluster/ClusterTaskSetManagerSuite.scala | 49
-rw-r--r--  docs/configuration.md | 8
-rw-r--r--  examples/src/main/scala/org/apache/spark/examples/BroadcastTest.scala | 15
-rw-r--r--  project/SparkBuild.scala | 13
-rw-r--r--  repl/src/main/scala/org/apache/spark/repl/SparkILoop.scala | 9
-rw-r--r--  yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocationHandler.scala | 4
28 files changed, 554 insertions(+), 146 deletions(-)
diff --git a/bin/compute-classpath.sh b/bin/compute-classpath.sh
index c7819d4932..c16afd6b36 100755
--- a/bin/compute-classpath.sh
+++ b/bin/compute-classpath.sh
@@ -32,12 +32,26 @@ fi
# Build up classpath
CLASSPATH="$SPARK_CLASSPATH:$FWDIR/conf"
-if [ -f "$FWDIR/RELEASE" ]; then
- ASSEMBLY_JAR=`ls "$FWDIR"/jars/spark-assembly*.jar`
+
+# First check if we have a dependencies jar. If so, include binary classes with the deps jar
+if [ -f "$FWDIR"/assembly/target/scala-$SCALA_VERSION/spark-assembly*hadoop*-deps.jar ]; then
+ CLASSPATH="$CLASSPATH:$FWDIR/core/target/scala-$SCALA_VERSION/classes"
+ CLASSPATH="$CLASSPATH:$FWDIR/repl/target/scala-$SCALA_VERSION/classes"
+ CLASSPATH="$CLASSPATH:$FWDIR/mllib/target/scala-$SCALA_VERSION/classes"
+ CLASSPATH="$CLASSPATH:$FWDIR/bagel/target/scala-$SCALA_VERSION/classes"
+ CLASSPATH="$CLASSPATH:$FWDIR/streaming/target/scala-$SCALA_VERSION/classes"
+
+ DEPS_ASSEMBLY_JAR=`ls "$FWDIR"/assembly/target/scala-$SCALA_VERSION/spark-assembly*hadoop*-deps.jar`
+ CLASSPATH="$CLASSPATH:$DEPS_ASSEMBLY_JAR"
else
- ASSEMBLY_JAR=`ls "$FWDIR"/assembly/target/scala-$SCALA_VERSION/spark-assembly*hadoop*.jar`
+ # Else use spark-assembly jar from either RELEASE or assembly directory
+ if [ -f "$FWDIR/RELEASE" ]; then
+ ASSEMBLY_JAR=`ls "$FWDIR"/jars/spark-assembly*.jar`
+ else
+ ASSEMBLY_JAR=`ls "$FWDIR"/assembly/target/scala-$SCALA_VERSION/spark-assembly*hadoop*.jar`
+ fi
+ CLASSPATH="$CLASSPATH:$ASSEMBLY_JAR"
fi
-CLASSPATH="$CLASSPATH:$ASSEMBLY_JAR"
# Add test classes if we're running from SBT or Maven with SPARK_TESTING set to 1
if [[ $SPARK_TESTING == 1 ]]; then
diff --git a/bin/slaves.sh b/bin/slaves.sh
index 752565b759..c367c2fd8e 100755
--- a/bin/slaves.sh
+++ b/bin/slaves.sh
@@ -28,7 +28,7 @@
# SPARK_SSH_OPTS Options passed to ssh when running remote commands.
##
-usage="Usage: slaves.sh [--config confdir] command..."
+usage="Usage: slaves.sh [--config <conf-dir>] command..."
# if no args specified, show usage
if [ $# -le 0 ]; then
@@ -46,6 +46,23 @@ bin=`cd "$bin"; pwd`
# spark-env.sh. Save it here.
HOSTLIST=$SPARK_SLAVES
+# Check if --config is passed as an argument. It is an optional parameter.
+# Exit if the argument is not a directory.
+if [ "$1" == "--config" ]
+then
+ shift
+ conf_dir=$1
+ if [ ! -d "$conf_dir" ]
+ then
+ echo "ERROR : $conf_dir is not a directory"
+ echo $usage
+ exit 1
+ else
+ export SPARK_CONF_DIR=$conf_dir
+ fi
+ shift
+fi
+
if [ -f "${SPARK_CONF_DIR}/spark-env.sh" ]; then
. "${SPARK_CONF_DIR}/spark-env.sh"
fi
diff --git a/bin/spark-daemon.sh b/bin/spark-daemon.sh
index 5bfe967fbf..a0c0d44b58 100755
--- a/bin/spark-daemon.sh
+++ b/bin/spark-daemon.sh
@@ -29,7 +29,7 @@
# SPARK_NICENESS The scheduling priority for daemons. Defaults to 0.
##
-usage="Usage: spark-daemon.sh [--config <conf-dir>] [--hosts hostlistfile] (start|stop) <spark-command> <spark-instance-number> <args...>"
+usage="Usage: spark-daemon.sh [--config <conf-dir>] (start|stop) <spark-command> <spark-instance-number> <args...>"
# if no args specified, show usage
if [ $# -le 1 ]; then
@@ -43,6 +43,25 @@ bin=`cd "$bin"; pwd`
. "$bin/spark-config.sh"
# get arguments
+
+# Check if --config is passed as an argument. It is an optional parameter.
+# Exit if the argument is not a directory.
+
+if [ "$1" == "--config" ]
+then
+ shift
+ conf_dir=$1
+ if [ ! -d "$conf_dir" ]
+ then
+ echo "ERROR : $conf_dir is not a directory"
+ echo $usage
+ exit 1
+ else
+ export SPARK_CONF_DIR=$conf_dir
+ fi
+ shift
+fi
+
startStop=$1
shift
command=$1
diff --git a/bin/spark-daemons.sh b/bin/spark-daemons.sh
index 354eb905a1..64286cb2da 100755
--- a/bin/spark-daemons.sh
+++ b/bin/spark-daemons.sh
@@ -19,7 +19,7 @@
# Run a Spark command on all slave hosts.
-usage="Usage: spark-daemons.sh [--config confdir] [--hosts hostlistfile] [start|stop] command instance-number args..."
+usage="Usage: spark-daemons.sh [--config <conf-dir>] [start|stop] command instance-number args..."
# if no args specified, show usage
if [ $# -le 1 ]; then
diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala
index 0aafc0a2fc..48bbc78795 100644
--- a/core/src/main/scala/org/apache/spark/SparkContext.scala
+++ b/core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -217,21 +217,20 @@ class SparkContext(
scheduler.initialize(backend)
scheduler
- case _ =>
- if (MESOS_REGEX.findFirstIn(master).isEmpty) {
- logWarning("Master %s does not match expected format, parsing as Mesos URL".format(master))
- }
+ case MESOS_REGEX(mesosUrl) =>
MesosNativeLibrary.load()
val scheduler = new ClusterScheduler(this)
val coarseGrained = System.getProperty("spark.mesos.coarse", "false").toBoolean
- val masterWithoutProtocol = master.replaceFirst("^mesos://", "") // Strip initial mesos://
val backend = if (coarseGrained) {
- new CoarseMesosSchedulerBackend(scheduler, this, masterWithoutProtocol, appName)
+ new CoarseMesosSchedulerBackend(scheduler, this, mesosUrl, appName)
} else {
- new MesosSchedulerBackend(scheduler, this, masterWithoutProtocol, appName)
+ new MesosSchedulerBackend(scheduler, this, mesosUrl, appName)
}
scheduler.initialize(backend)
scheduler
+
+ case _ =>
+ throw new SparkException("Could not parse Master URL: '" + master + "'")
}
}
taskScheduler.start()
@@ -288,8 +287,19 @@ class SparkContext(
Option(localProperties.get).map(_.getProperty(key)).getOrElse(null)
/** Set a human readable description of the current job. */
+ @deprecated("use setJobGroup", "0.8.1")
def setJobDescription(value: String) {
- setLocalProperty(SparkContext.SPARK_JOB_DESCRIPTION, value)
+ setJobGroup("", value)
+ }
+
+ def setJobGroup(groupId: String, description: String) {
+ setLocalProperty(SparkContext.SPARK_JOB_DESCRIPTION, description)
+ setLocalProperty(SparkContext.SPARK_JOB_GROUP_ID, groupId)
+ }
+
+ def clearJobGroup() {
+ setLocalProperty(SparkContext.SPARK_JOB_DESCRIPTION, null)
+ setLocalProperty(SparkContext.SPARK_JOB_GROUP_ID, null)
}
// Post init
@@ -867,10 +877,14 @@ class SparkContext(
callSite,
allowLocal = false,
resultHandler,
- null)
+ localProperties.get)
new SimpleFutureAction(waiter, resultFunc)
}
+ def cancelJobGroup(groupId: String) {
+ dagScheduler.cancelJobGroup(groupId)
+ }
+
/**
* Cancel all jobs that have been scheduled or are running.
*/
@@ -934,8 +948,11 @@ class SparkContext(
* various Spark features.
*/
object SparkContext {
+
val SPARK_JOB_DESCRIPTION = "spark.job.description"
+ val SPARK_JOB_GROUP_ID = "spark.jobGroup.id"
+
implicit object DoubleAccumulatorParam extends AccumulatorParam[Double] {
def addInPlace(t1: Double, t2: Double): Double = t1 + t2
def zero(initialValue: Double) = 0.0
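
[Editor's note: the new setJobGroup / clearJobGroup / cancelJobGroup methods above are the user-facing half of this change. A minimal driver-side sketch of how they fit together, assuming the async count action (countAsync) introduced alongside this work is available through the usual SparkContext._ implicits; names here are illustrative, not part of the patch.]

import org.apache.spark.SparkContext
import org.apache.spark.SparkContext._  // assumed: brings countAsync() into scope via the async-action implicits

object JobGroupSketch {
  def main(args: Array[String]) {
    val sc = new SparkContext("local[2]", "job-group-sketch")

    // Tag every job submitted from this thread with a group id and a description.
    sc.setJobGroup("nightly-etl", "jobs kicked off by the nightly scheduler")
    val pending = sc.parallelize(1 to 1000000, 4).map(_ * 2).countAsync()

    // From any other thread: cancel all active jobs that carry this group id.
    sc.cancelJobGroup("nightly-etl")

    // Stop tagging subsequent jobs submitted from this thread.
    sc.clearJobGroup()
    sc.stop()
  }
}
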
diff --git a/core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala b/core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala
index 2bab9d6e3d..afa76a4a76 100644
--- a/core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala
+++ b/core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala
@@ -36,7 +36,10 @@ import org.apache.spark.SerializableWritable
* Saves the RDD using a JobConf, which should contain an output key class, an output value class,
* a filename to write to, etc, exactly like in a Hadoop MapReduce job.
*/
-class SparkHadoopWriter(@transient jobConf: JobConf) extends Logging with SparkHadoopMapRedUtil with Serializable {
+class SparkHadoopWriter(@transient jobConf: JobConf)
+ extends Logging
+ with SparkHadoopMapRedUtil
+ with Serializable {
private val now = new Date()
private val conf = new SerializableWritable(jobConf)
diff --git a/core/src/main/scala/org/apache/spark/TaskContext.scala b/core/src/main/scala/org/apache/spark/TaskContext.scala
index 51584d686d..cae983ed4c 100644
--- a/core/src/main/scala/org/apache/spark/TaskContext.scala
+++ b/core/src/main/scala/org/apache/spark/TaskContext.scala
@@ -22,7 +22,7 @@ import scala.collection.mutable.ArrayBuffer
import org.apache.spark.executor.TaskMetrics
class TaskContext(
- private[spark] val stageId: Int,
+ val stageId: Int,
val partitionId: Int,
val attemptId: Long,
val runningLocally: Boolean = false,
diff --git a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala
new file mode 100644
index 0000000000..073a0a5029
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala
@@ -0,0 +1,247 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.broadcast
+
+import java.io._
+
+import scala.math
+import scala.util.Random
+
+import org.apache.spark._
+import org.apache.spark.storage.{BroadcastBlockId, BroadcastHelperBlockId, StorageLevel}
+import org.apache.spark.util.Utils
+
+
+private[spark] class TorrentBroadcast[T](@transient var value_ : T, isLocal: Boolean, id: Long)
+extends Broadcast[T](id) with Logging with Serializable {
+
+ def value = value_
+
+ def broadcastId = BroadcastBlockId(id)
+
+ TorrentBroadcast.synchronized {
+ SparkEnv.get.blockManager.putSingle(broadcastId, value_, StorageLevel.MEMORY_AND_DISK, false)
+ }
+
+ @transient var arrayOfBlocks: Array[TorrentBlock] = null
+ @transient var totalBlocks = -1
+ @transient var totalBytes = -1
+ @transient var hasBlocks = 0
+
+ if (!isLocal) {
+ sendBroadcast()
+ }
+
+ def sendBroadcast() {
+ var tInfo = TorrentBroadcast.blockifyObject(value_)
+
+ totalBlocks = tInfo.totalBlocks
+ totalBytes = tInfo.totalBytes
+ hasBlocks = tInfo.totalBlocks
+
+ // Store meta-info
+ val metaId = BroadcastHelperBlockId(broadcastId, "meta")
+ val metaInfo = TorrentInfo(null, totalBlocks, totalBytes)
+ TorrentBroadcast.synchronized {
+ SparkEnv.get.blockManager.putSingle(
+ metaId, metaInfo, StorageLevel.MEMORY_AND_DISK, true)
+ }
+
+ // Store individual pieces
+ for (i <- 0 until totalBlocks) {
+ val pieceId = BroadcastHelperBlockId(broadcastId, "piece" + i)
+ TorrentBroadcast.synchronized {
+ SparkEnv.get.blockManager.putSingle(
+ pieceId, tInfo.arrayOfBlocks(i), StorageLevel.MEMORY_AND_DISK, true)
+ }
+ }
+ }
+
+ // Called by JVM when deserializing an object
+ private def readObject(in: ObjectInputStream) {
+ in.defaultReadObject()
+ TorrentBroadcast.synchronized {
+ SparkEnv.get.blockManager.getSingle(broadcastId) match {
+ case Some(x) =>
+ value_ = x.asInstanceOf[T]
+
+ case None =>
+ val start = System.nanoTime
+ logInfo("Started reading broadcast variable " + id)
+
+ // Initialize @transient variables that will receive garbage values from the master.
+ resetWorkerVariables()
+
+ if (receiveBroadcast(id)) {
+ value_ = TorrentBroadcast.unBlockifyObject[T](arrayOfBlocks, totalBytes, totalBlocks)
+
+ // Store the merged copy in cache so that the next worker doesn't need to rebuild it.
+ // This creates a tradeoff between memory usage and latency.
+ // Storing copy doubles the memory footprint; not storing doubles deserialization cost.
+ SparkEnv.get.blockManager.putSingle(
+ broadcastId, value_, StorageLevel.MEMORY_AND_DISK, false)
+
+ // Remove arrayOfBlocks from memory once value_ is on local cache
+ resetWorkerVariables()
+ } else {
+ logError("Reading broadcast variable " + id + " failed")
+ }
+
+ val time = (System.nanoTime - start) / 1e9
+ logInfo("Reading broadcast variable " + id + " took " + time + " s")
+ }
+ }
+ }
+
+ private def resetWorkerVariables() {
+ arrayOfBlocks = null
+ totalBytes = -1
+ totalBlocks = -1
+ hasBlocks = 0
+ }
+
+ def receiveBroadcast(variableID: Long): Boolean = {
+ // Receive meta-info
+ val metaId = BroadcastHelperBlockId(broadcastId, "meta")
+ var attemptId = 10
+ while (attemptId > 0 && totalBlocks == -1) {
+ TorrentBroadcast.synchronized {
+ SparkEnv.get.blockManager.getSingle(metaId) match {
+ case Some(x) =>
+ val tInfo = x.asInstanceOf[TorrentInfo]
+ totalBlocks = tInfo.totalBlocks
+ totalBytes = tInfo.totalBytes
+ arrayOfBlocks = new Array[TorrentBlock](totalBlocks)
+ hasBlocks = 0
+
+ case None =>
+ Thread.sleep(500)
+ }
+ }
+ attemptId -= 1
+ }
+ if (totalBlocks == -1) {
+ return false
+ }
+
+ // Receive actual blocks
+ val recvOrder = new Random().shuffle(Array.iterate(0, totalBlocks)(_ + 1).toList)
+ for (pid <- recvOrder) {
+ val pieceId = BroadcastHelperBlockId(broadcastId, "piece" + pid)
+ TorrentBroadcast.synchronized {
+ SparkEnv.get.blockManager.getSingle(pieceId) match {
+ case Some(x) =>
+ arrayOfBlocks(pid) = x.asInstanceOf[TorrentBlock]
+ hasBlocks += 1
+ SparkEnv.get.blockManager.putSingle(
+ pieceId, arrayOfBlocks(pid), StorageLevel.MEMORY_AND_DISK, true)
+
+ case None =>
+ throw new SparkException("Failed to get " + pieceId + " of " + broadcastId)
+ }
+ }
+ }
+
+ (hasBlocks == totalBlocks)
+ }
+
+}
+
+private object TorrentBroadcast
+extends Logging {
+
+ private var initialized = false
+
+ def initialize(_isDriver: Boolean) {
+ synchronized {
+ if (!initialized) {
+ initialized = true
+ }
+ }
+ }
+
+ def stop() {
+ initialized = false
+ }
+
+ val BLOCK_SIZE = System.getProperty("spark.broadcast.blockSize", "4096").toInt * 1024
+
+ def blockifyObject[T](obj: T): TorrentInfo = {
+ val byteArray = Utils.serialize[T](obj)
+ val bais = new ByteArrayInputStream(byteArray)
+
+ var blockNum = (byteArray.length / BLOCK_SIZE)
+ if (byteArray.length % BLOCK_SIZE != 0)
+ blockNum += 1
+
+ var retVal = new Array[TorrentBlock](blockNum)
+ var blockID = 0
+
+ for (i <- 0 until (byteArray.length, BLOCK_SIZE)) {
+ val thisBlockSize = math.min(BLOCK_SIZE, byteArray.length - i)
+ var tempByteArray = new Array[Byte](thisBlockSize)
+ val hasRead = bais.read(tempByteArray, 0, thisBlockSize)
+
+ retVal(blockID) = new TorrentBlock(blockID, tempByteArray)
+ blockID += 1
+ }
+ bais.close()
+
+ var tInfo = TorrentInfo(retVal, blockNum, byteArray.length)
+ tInfo.hasBlocks = blockNum
+
+ return tInfo
+ }
+
+ def unBlockifyObject[T](arrayOfBlocks: Array[TorrentBlock],
+ totalBytes: Int,
+ totalBlocks: Int): T = {
+ var retByteArray = new Array[Byte](totalBytes)
+ for (i <- 0 until totalBlocks) {
+ System.arraycopy(arrayOfBlocks(i).byteArray, 0, retByteArray,
+ i * BLOCK_SIZE, arrayOfBlocks(i).byteArray.length)
+ }
+ Utils.deserialize[T](retByteArray, Thread.currentThread.getContextClassLoader)
+ }
+
+}
+
+private[spark] case class TorrentBlock(
+ blockID: Int,
+ byteArray: Array[Byte])
+ extends Serializable
+
+private[spark] case class TorrentInfo(
+ @transient arrayOfBlocks : Array[TorrentBlock],
+ totalBlocks: Int,
+ totalBytes: Int)
+ extends Serializable {
+
+ @transient var hasBlocks = 0
+}
+
+private[spark] class TorrentBroadcastFactory
+ extends BroadcastFactory {
+
+ def initialize(isDriver: Boolean) { TorrentBroadcast.initialize(isDriver) }
+
+ def newBroadcast[T](value_ : T, isLocal: Boolean, id: Long) =
+ new TorrentBroadcast[T](value_, isLocal, id)
+
+ def stop() { TorrentBroadcast.stop() }
+}
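
[Editor's note: to exercise TorrentBroadcast, the factory must be selected before the SparkContext (and hence SparkEnv) is created, exactly as the tests later in this diff do. A minimal sketch; the application name and data are placeholders.]

import org.apache.spark.SparkContext

object TorrentBroadcastSketch {
  def main(args: Array[String]) {
    // Must be set before constructing the SparkContext, because the broadcast
    // factory is chosen when SparkEnv initializes the broadcast manager.
    System.setProperty("spark.broadcast.factory",
      "org.apache.spark.broadcast.TorrentBroadcastFactory")
    System.setProperty("spark.broadcast.blockSize", "4096")  // piece size in KB

    val sc = new SparkContext("local[4]", "torrent-broadcast-sketch")
    val table = sc.broadcast((1 to 1000).map(i => i -> i * i).toMap)
    val matched = sc.parallelize(1 to 1000, 4).filter(i => table.value.contains(i)).count()
    println("matched " + matched + " elements")
    sc.stop()
  }
}
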
diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
index 7fb614402b..d84f5968df 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -55,20 +55,20 @@ class DAGScheduler(
mapOutputTracker: MapOutputTracker,
blockManagerMaster: BlockManagerMaster,
env: SparkEnv)
- extends TaskSchedulerListener with Logging {
+ extends Logging {
def this(taskSched: TaskScheduler) {
this(taskSched, SparkEnv.get.mapOutputTracker, SparkEnv.get.blockManager.master, SparkEnv.get)
}
- taskSched.setListener(this)
+ taskSched.setDAGScheduler(this)
// Called by TaskScheduler to report task's starting.
- override def taskStarted(task: Task[_], taskInfo: TaskInfo) {
+ def taskStarted(task: Task[_], taskInfo: TaskInfo) {
eventQueue.put(BeginEvent(task, taskInfo))
}
// Called by TaskScheduler to report task completions or failures.
- override def taskEnded(
+ def taskEnded(
task: Task[_],
reason: TaskEndReason,
result: Any,
@@ -79,18 +79,18 @@ class DAGScheduler(
}
// Called by TaskScheduler when an executor fails.
- override def executorLost(execId: String) {
+ def executorLost(execId: String) {
eventQueue.put(ExecutorLost(execId))
}
// Called by TaskScheduler when a host is added
- override def executorGained(execId: String, host: String) {
+ def executorGained(execId: String, host: String) {
eventQueue.put(ExecutorGained(execId, host))
}
// Called by TaskScheduler to cancel an entire TaskSet due to either repeated failures or
// cancellation of the job itself.
- override def taskSetFailed(taskSet: TaskSet, reason: String) {
+ def taskSetFailed(taskSet: TaskSet, reason: String) {
eventQueue.put(TaskSetFailed(taskSet, reason))
}
@@ -277,11 +277,6 @@ class DAGScheduler(
resultHandler: (Int, U) => Unit,
properties: Properties = null): JobWaiter[U] =
{
- val jobId = nextJobId.getAndIncrement()
- if (partitions.size == 0) {
- return new JobWaiter[U](this, jobId, 0, resultHandler)
- }
-
// Check to make sure we are not launching a task on a partition that does not exist.
val maxPartitions = rdd.partitions.length
partitions.find(p => p >= maxPartitions).foreach { p =>
@@ -290,6 +285,11 @@ class DAGScheduler(
"Total number of partitions: " + maxPartitions)
}
+ val jobId = nextJobId.getAndIncrement()
+ if (partitions.size == 0) {
+ return new JobWaiter[U](this, jobId, 0, resultHandler)
+ }
+
assert(partitions.size > 0)
val func2 = func.asInstanceOf[(TaskContext, Iterator[_]) => _]
val waiter = new JobWaiter(this, jobId, partitions.size, resultHandler)
@@ -342,6 +342,11 @@ class DAGScheduler(
eventQueue.put(JobCancelled(jobId))
}
+ def cancelJobGroup(groupId: String) {
+ logInfo("Asked to cancel job group " + groupId)
+ eventQueue.put(JobGroupCancelled(groupId))
+ }
+
/**
* Cancel all jobs that are running or waiting in the queue.
*/
@@ -381,6 +386,17 @@ class DAGScheduler(
taskSched.cancelTasks(stage.id)
}
+ case JobGroupCancelled(groupId) =>
+ // Cancel all jobs belonging to this job group.
+ // First finds all active jobs with this group id, and then kill stages for them.
+ val jobIds = activeJobs.filter(groupId == _.properties.get(SparkContext.SPARK_JOB_GROUP_ID))
+ .map(_.jobId)
+ if (!jobIds.isEmpty) {
+ running.filter(stage => jobIds.contains(stage.jobId)).foreach { stage =>
+ taskSched.cancelTasks(stage.id)
+ }
+ }
+
case AllJobsCancelled =>
// Cancel all running jobs.
running.foreach { stage =>
diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala
index ee89bfb38d..a5769c6041 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala
@@ -46,6 +46,8 @@ private[scheduler] case class JobSubmitted(
private[scheduler] case class JobCancelled(jobId: Int) extends DAGSchedulerEvent
+private[scheduler] case class JobGroupCancelled(groupId: String) extends DAGSchedulerEvent
+
private[scheduler] case object AllJobsCancelled extends DAGSchedulerEvent
private[scheduler]
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala
index 6a51efe8d6..10e0478108 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala
@@ -24,8 +24,7 @@ import org.apache.spark.scheduler.SchedulingMode.SchedulingMode
* Each TaskScheduler schedulers task for a single SparkContext.
* These schedulers get sets of tasks submitted to them from the DAGScheduler for each stage,
* and are responsible for sending the tasks to the cluster, running them, retrying if there
- * are failures, and mitigating stragglers. They return events to the DAGScheduler through
- * the TaskSchedulerListener interface.
+ * are failures, and mitigating stragglers. They return events to the DAGScheduler.
*/
private[spark] trait TaskScheduler {
@@ -48,8 +47,8 @@ private[spark] trait TaskScheduler {
// Cancel a stage.
def cancelTasks(stageId: Int)
- // Set a listener for upcalls. This is guaranteed to be set before submitTasks is called.
- def setListener(listener: TaskSchedulerListener): Unit
+ // Set the DAG scheduler for upcalls. This is guaranteed to be set before submitTasks is called.
+ def setDAGScheduler(dagScheduler: DAGScheduler): Unit
// Get the default level of parallelism to use in the cluster, as a hint for sizing jobs.
def defaultParallelism(): Int
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerListener.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerListener.scala
deleted file mode 100644
index 593fa9fb93..0000000000
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerListener.scala
+++ /dev/null
@@ -1,44 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.scheduler
-
-import scala.collection.mutable.Map
-
-import org.apache.spark.TaskEndReason
-import org.apache.spark.executor.TaskMetrics
-
-/**
- * Interface for getting events back from the TaskScheduler.
- */
-private[spark] trait TaskSchedulerListener {
- // A task has started.
- def taskStarted(task: Task[_], taskInfo: TaskInfo)
-
- // A task has finished or failed.
- def taskEnded(task: Task[_], reason: TaskEndReason, result: Any, accumUpdates: Map[Long, Any],
- taskInfo: TaskInfo, taskMetrics: TaskMetrics): Unit
-
- // A node was added to the cluster.
- def executorGained(execId: String, host: String): Unit
-
- // A node was lost from the cluster.
- def executorLost(execId: String): Unit
-
- // The TaskScheduler wants to abort an entire task set.
- def taskSetFailed(taskSet: TaskSet, reason: String): Unit
-}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterScheduler.scala
index 7a72ff0474..4ea8bf8853 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterScheduler.scala
@@ -79,7 +79,7 @@ private[spark] class ClusterScheduler(val sc: SparkContext)
private val executorIdToHost = new HashMap[String, String]
// Listener object to pass upcalls into
- var listener: TaskSchedulerListener = null
+ var dagScheduler: DAGScheduler = null
var backend: SchedulerBackend = null
@@ -94,8 +94,8 @@ private[spark] class ClusterScheduler(val sc: SparkContext)
// This is a var so that we can reset it for testing purposes.
private[spark] var taskResultGetter = new TaskResultGetter(sc.env, this)
- override def setListener(listener: TaskSchedulerListener) {
- this.listener = listener
+ override def setDAGScheduler(dagScheduler: DAGScheduler) {
+ this.dagScheduler = dagScheduler
}
def initialize(context: SchedulerBackend) {
@@ -297,7 +297,7 @@ private[spark] class ClusterScheduler(val sc: SparkContext)
}
// Update the DAGScheduler without holding a lock on this, since that can deadlock
if (failedExecutor != None) {
- listener.executorLost(failedExecutor.get)
+ dagScheduler.executorLost(failedExecutor.get)
backend.reviveOffers()
}
if (taskFailed) {
@@ -397,9 +397,9 @@ private[spark] class ClusterScheduler(val sc: SparkContext)
logError("Lost an executor " + executorId + " (already removed): " + reason)
}
}
- // Call listener.executorLost without holding the lock on this to prevent deadlock
+ // Call dagScheduler.executorLost without holding the lock on this to prevent deadlock
if (failedExecutor != None) {
- listener.executorLost(failedExecutor.get)
+ dagScheduler.executorLost(failedExecutor.get)
backend.reviveOffers()
}
}
@@ -418,7 +418,7 @@ private[spark] class ClusterScheduler(val sc: SparkContext)
}
def executorGained(execId: String, host: String) {
- listener.executorGained(execId, host)
+ dagScheduler.executorGained(execId, host)
}
def getExecutorsAliveOnHost(host: String): Option[Set[String]] = synchronized {
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterTaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterTaskSetManager.scala
index 7bd3499300..29093e3b4f 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterTaskSetManager.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterTaskSetManager.scala
@@ -415,11 +415,11 @@ private[spark] class ClusterTaskSetManager(
}
private def taskStarted(task: Task[_], info: TaskInfo) {
- sched.listener.taskStarted(task, info)
+ sched.dagScheduler.taskStarted(task, info)
}
/**
- * Marks the task as successful and notifies the listener that a task has ended.
+ * Marks the task as successful and notifies the DAGScheduler that a task has ended.
*/
def handleSuccessfulTask(tid: Long, result: DirectTaskResult[_]) = {
val info = taskInfos(tid)
@@ -429,7 +429,7 @@ private[spark] class ClusterTaskSetManager(
if (!successful(index)) {
logInfo("Finished TID %s in %d ms on %s (progress: %d/%d)".format(
tid, info.duration, info.host, tasksSuccessful, numTasks))
- sched.listener.taskEnded(
+ sched.dagScheduler.taskEnded(
tasks(index), Success, result.value, result.accumUpdates, info, result.metrics)
// Mark successful and stop if all the tasks have succeeded.
@@ -445,7 +445,8 @@ private[spark] class ClusterTaskSetManager(
}
/**
- * Marks the task as failed, re-adds it to the list of pending tasks, and notifies the listener.
+ * Marks the task as failed, re-adds it to the list of pending tasks, and notifies the
+ * DAG Scheduler.
*/
def handleFailedTask(tid: Long, state: TaskState, reason: Option[TaskEndReason]) {
val info = taskInfos(tid)
@@ -463,7 +464,7 @@ private[spark] class ClusterTaskSetManager(
reason.foreach {
case fetchFailed: FetchFailed =>
logWarning("Loss was due to fetch failure from " + fetchFailed.bmAddress)
- sched.listener.taskEnded(tasks(index), fetchFailed, null, null, info, null)
+ sched.dagScheduler.taskEnded(tasks(index), fetchFailed, null, null, info, null)
successful(index) = true
tasksSuccessful += 1
sched.taskSetFinished(this)
@@ -472,11 +473,11 @@ private[spark] class ClusterTaskSetManager(
case TaskKilled =>
logWarning("Task %d was killed.".format(tid))
- sched.listener.taskEnded(tasks(index), reason.get, null, null, info, null)
+ sched.dagScheduler.taskEnded(tasks(index), reason.get, null, null, info, null)
return
case ef: ExceptionFailure =>
- sched.listener.taskEnded(tasks(index), ef, null, null, info, ef.metrics.getOrElse(null))
+ sched.dagScheduler.taskEnded(tasks(index), ef, null, null, info, ef.metrics.getOrElse(null))
val key = ef.description
val now = clock.getTime()
val (printFull, dupCount) = {
@@ -504,7 +505,7 @@ private[spark] class ClusterTaskSetManager(
case TaskResultLost =>
logWarning("Lost result for TID %s on host %s".format(tid, info.host))
- sched.listener.taskEnded(tasks(index), TaskResultLost, null, null, info, null)
+ sched.dagScheduler.taskEnded(tasks(index), TaskResultLost, null, null, info, null)
case _ => {}
}
@@ -533,7 +534,7 @@ private[spark] class ClusterTaskSetManager(
failed = true
causeOfFailure = message
// TODO: Kill running tasks if we were not terminated due to a Mesos error
- sched.listener.taskSetFailed(taskSet, message)
+ sched.dagScheduler.taskSetFailed(taskSet, message)
removeAllRunningTasks()
sched.taskSetFinished(this)
}
@@ -606,7 +607,7 @@ private[spark] class ClusterTaskSetManager(
addPendingTask(index)
// Tell the DAGScheduler that this task was resubmitted so that it doesn't think our
// stage finishes when a total of tasks.size tasks finish.
- sched.listener.taskEnded(tasks(index), Resubmitted, null, null, info, null)
+ sched.dagScheduler.taskEnded(tasks(index), Resubmitted, null, null, info, null)
}
}
}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/local/LocalScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/local/LocalScheduler.scala
index b445260d1b..2699f0b33e 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/local/LocalScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/local/LocalScheduler.scala
@@ -81,7 +81,7 @@ private[spark] class LocalScheduler(threads: Int, val maxFailures: Int, val sc:
val env = SparkEnv.get
val attemptId = new AtomicInteger
- var listener: TaskSchedulerListener = null
+ var dagScheduler: DAGScheduler = null
// Application dependencies (added through SparkContext) that we've fetched so far on this node.
// Each map holds the master's timestamp for the version of that file or JAR we got.
@@ -114,8 +114,8 @@ private[spark] class LocalScheduler(threads: Int, val maxFailures: Int, val sc:
localActor = env.actorSystem.actorOf(Props(new LocalActor(this, threads)), "Test")
}
- override def setListener(listener: TaskSchedulerListener) {
- this.listener = listener
+ override def setDAGScheduler(dagScheduler: DAGScheduler) {
+ this.dagScheduler = dagScheduler
}
override def submitTasks(taskSet: TaskSet) {
diff --git a/core/src/main/scala/org/apache/spark/scheduler/local/LocalTaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/local/LocalTaskSetManager.scala
index f72e77d40f..55f8313e87 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/local/LocalTaskSetManager.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/local/LocalTaskSetManager.scala
@@ -133,7 +133,7 @@ private[spark] class LocalTaskSetManager(sched: LocalScheduler, val taskSet: Tas
}
def taskStarted(task: Task[_], info: TaskInfo) {
- sched.listener.taskStarted(task, info)
+ sched.dagScheduler.taskStarted(task, info)
}
def taskEnded(tid: Long, state: TaskState, serializedData: ByteBuffer) {
@@ -148,7 +148,8 @@ private[spark] class LocalTaskSetManager(sched: LocalScheduler, val taskSet: Tas
}
}
result.metrics.resultSize = serializedData.limit()
- sched.listener.taskEnded(task, Success, result.value, result.accumUpdates, info, result.metrics)
+ sched.dagScheduler.taskEnded(task, Success, result.value, result.accumUpdates, info,
+ result.metrics)
numFinished += 1
decreaseRunningTasks(1)
finished(index) = true
@@ -165,7 +166,7 @@ private[spark] class LocalTaskSetManager(sched: LocalScheduler, val taskSet: Tas
decreaseRunningTasks(1)
val reason: ExceptionFailure = ser.deserialize[ExceptionFailure](
serializedData, getClass.getClassLoader)
- sched.listener.taskEnded(task, reason, null, null, info, reason.metrics.getOrElse(null))
+ sched.dagScheduler.taskEnded(task, reason, null, null, info, reason.metrics.getOrElse(null))
if (!finished(index)) {
copiesRunning(index) -= 1
numFailures(index) += 1
@@ -176,7 +177,7 @@ private[spark] class LocalTaskSetManager(sched: LocalScheduler, val taskSet: Tas
val errorMessage = "Task %s:%d failed more than %d times; aborting job %s".format(
taskSet.id, index, 4, reason.description)
decreaseRunningTasks(runningTasks)
- sched.listener.taskSetFailed(taskSet, errorMessage)
+ sched.dagScheduler.taskSetFailed(taskSet, errorMessage)
// need to delete failed Taskset from schedule queue
sched.taskSetFinished(this)
}
@@ -184,7 +185,7 @@ private[spark] class LocalTaskSetManager(sched: LocalScheduler, val taskSet: Tas
}
override def error(message: String) {
- sched.listener.taskSetFailed(taskSet, message)
+ sched.dagScheduler.taskSetFailed(taskSet, message)
sched.taskSetFinished(this)
}
}
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockId.scala b/core/src/main/scala/org/apache/spark/storage/BlockId.scala
index c7efc67a4a..7156d855d8 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockId.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockId.scala
@@ -32,7 +32,7 @@ private[spark] sealed abstract class BlockId {
def asRDDId = if (isRDD) Some(asInstanceOf[RDDBlockId]) else None
def isRDD = isInstanceOf[RDDBlockId]
def isShuffle = isInstanceOf[ShuffleBlockId]
- def isBroadcast = isInstanceOf[BroadcastBlockId]
+ def isBroadcast = isInstanceOf[BroadcastBlockId] || isInstanceOf[BroadcastHelperBlockId]
override def toString = name
override def hashCode = name.hashCode
@@ -55,6 +55,10 @@ private[spark] case class BroadcastBlockId(broadcastId: Long) extends BlockId {
def name = "broadcast_" + broadcastId
}
+private[spark] case class BroadcastHelperBlockId(broadcastId: BroadcastBlockId, hType: String) extends BlockId {
+ def name = broadcastId.name + "_" + hType
+}
+
private[spark] case class TaskResultBlockId(taskId: Long) extends BlockId {
def name = "taskresult_" + taskId
}
@@ -72,6 +76,7 @@ private[spark] object BlockId {
val RDD = "rdd_([0-9]+)_([0-9]+)".r
val SHUFFLE = "shuffle_([0-9]+)_([0-9]+)_([0-9]+)".r
val BROADCAST = "broadcast_([0-9]+)".r
+ val BROADCAST_HELPER = "broadcast_([0-9]+)_([A-Za-z0-9]+)".r
val TASKRESULT = "taskresult_([0-9]+)".r
val STREAM = "input-([0-9]+)-([0-9]+)".r
val TEST = "test_(.*)".r
@@ -84,6 +89,8 @@ private[spark] object BlockId {
ShuffleBlockId(shuffleId.toInt, mapId.toInt, reduceId.toInt)
case BROADCAST(broadcastId) =>
BroadcastBlockId(broadcastId.toLong)
+ case BROADCAST_HELPER(broadcastId, hType) =>
+ BroadcastHelperBlockId(BroadcastBlockId(broadcastId.toLong), hType)
case TASKRESULT(taskId) =>
TaskResultBlockId(taskId.toLong)
case STREAM(streamId, uniqueId) =>
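
[Editor's note: the helper block names written by TorrentBroadcast ("broadcast_<id>_meta", "broadcast_<id>_piece<n>") round-trip through the new BROADCAST_HELPER pattern. A small sketch, assuming the match shown above is exposed as BlockId.apply (as in the upstream source) and placed inside the storage package because these types are private[spark].]

package org.apache.spark.storage

private[spark] object BroadcastHelperBlockIdSketch {
  def main(args: Array[String]) {
    val meta = BroadcastHelperBlockId(BroadcastBlockId(7), "meta")
    assert(meta.name == "broadcast_7_meta")

    // The plain BROADCAST pattern is tried first; "broadcast_7_meta" does not fully
    // match it, so parsing falls through to BROADCAST_HELPER and recovers the id intact.
    assert(BlockId("broadcast_7_meta") == meta)

    // Helper blocks now also count as broadcast blocks (see isBroadcast above).
    assert(meta.isBroadcast)
  }
}
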
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
index 801f88a3db..c67a61515e 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
@@ -21,6 +21,7 @@ import java.io.{InputStream, OutputStream}
import java.nio.{ByteBuffer, MappedByteBuffer}
import scala.collection.mutable.{HashMap, ArrayBuffer, HashSet}
+import scala.util.Random
import akka.actor.{ActorSystem, Cancellable, Props}
import akka.dispatch.{Await, Future}
@@ -269,7 +270,7 @@ private[spark] class BlockManager(
}
/**
- * Actually send a UpdateBlockInfo message. Returns the mater's response,
+ * Actually send a UpdateBlockInfo message. Returns the master's response,
* which will be true if the block was successfully recorded and false if
* the slave needs to re-register.
*/
@@ -478,7 +479,7 @@ private[spark] class BlockManager(
}
logDebug("Getting remote block " + blockId)
// Get locations of block
- val locations = master.getLocations(blockId)
+ val locations = Random.shuffle(master.getLocations(blockId))
// Get block from remote locations
for (loc <- locations) {
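
[Editor's note: shuffling the location list is what spreads remote fetches of a hot block, such as torrent broadcast pieces, across replicas instead of always hitting the first block manager the master returns. The idea in isolation, with hypothetical block manager ids:]

import scala.util.Random

object ShuffledLocationsSketch {
  def main(args: Array[String]) {
    val locations = Seq("bm-host1", "bm-host2", "bm-host3")  // hypothetical replica list
    // Each fetcher shuffles independently, so the "first choice" is spread roughly
    // evenly across replicas rather than concentrating on locations.head.
    val firstChoices = (1 to 6).map(_ => Random.shuffle(locations).head)
    println(firstChoices.mkString(", "))
  }
}
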
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala
index 633230c0a8..f8cf14b503 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala
@@ -227,9 +227,7 @@ class BlockManagerMasterActor(val isLocal: Boolean) extends Actor with Logging {
}
private def register(id: BlockManagerId, maxMemSize: Long, slaveActor: ActorRef) {
- if (id.executorId == "<driver>" && !isLocal) {
- // Got a register message from the master node; don't register it
- } else if (!blockManagerInfo.contains(id)) {
+ if (!blockManagerInfo.contains(id)) {
blockManagerIdByExecutor.get(id.executorId) match {
case Some(manager) =>
// A block manager of the same executor already exists.
diff --git a/core/src/test/scala/org/apache/spark/BroadcastSuite.scala b/core/src/test/scala/org/apache/spark/BroadcastSuite.scala
index b3a53d928b..e022accee6 100644
--- a/core/src/test/scala/org/apache/spark/BroadcastSuite.scala
+++ b/core/src/test/scala/org/apache/spark/BroadcastSuite.scala
@@ -20,8 +20,42 @@ package org.apache.spark
import org.scalatest.FunSuite
class BroadcastSuite extends FunSuite with LocalSparkContext {
-
- test("basic broadcast") {
+
+ override def afterEach() {
+ super.afterEach()
+ System.clearProperty("spark.broadcast.factory")
+ }
+
+ test("Using HttpBroadcast locally") {
+ System.setProperty("spark.broadcast.factory", "org.apache.spark.broadcast.HttpBroadcastFactory")
+ sc = new SparkContext("local", "test")
+ val list = List(1, 2, 3, 4)
+ val listBroadcast = sc.broadcast(list)
+ val results = sc.parallelize(1 to 2).map(x => (x, listBroadcast.value.sum))
+ assert(results.collect.toSet === Set((1, 10), (2, 10)))
+ }
+
+ test("Accessing HttpBroadcast variables from multiple threads") {
+ System.setProperty("spark.broadcast.factory", "org.apache.spark.broadcast.HttpBroadcastFactory")
+ sc = new SparkContext("local[10]", "test")
+ val list = List(1, 2, 3, 4)
+ val listBroadcast = sc.broadcast(list)
+ val results = sc.parallelize(1 to 10).map(x => (x, listBroadcast.value.sum))
+ assert(results.collect.toSet === (1 to 10).map(x => (x, 10)).toSet)
+ }
+
+ test("Accessing HttpBroadcast variables in a local cluster") {
+ System.setProperty("spark.broadcast.factory", "org.apache.spark.broadcast.HttpBroadcastFactory")
+ val numSlaves = 4
+ sc = new SparkContext("local-cluster[%d, 1, 512]".format(numSlaves), "test")
+ val list = List(1, 2, 3, 4)
+ val listBroadcast = sc.broadcast(list)
+ val results = sc.parallelize(1 to numSlaves).map(x => (x, listBroadcast.value.sum))
+ assert(results.collect.toSet === (1 to numSlaves).map(x => (x, 10)).toSet)
+ }
+
+ test("Using TorrentBroadcast locally") {
+ System.setProperty("spark.broadcast.factory", "org.apache.spark.broadcast.TorrentBroadcastFactory")
sc = new SparkContext("local", "test")
val list = List(1, 2, 3, 4)
val listBroadcast = sc.broadcast(list)
@@ -29,11 +63,23 @@ class BroadcastSuite extends FunSuite with LocalSparkContext {
assert(results.collect.toSet === Set((1, 10), (2, 10)))
}
- test("broadcast variables accessed in multiple threads") {
+ test("Accessing TorrentBroadcast variables from multiple threads") {
+ System.setProperty("spark.broadcast.factory", "org.apache.spark.broadcast.TorrentBroadcastFactory")
sc = new SparkContext("local[10]", "test")
val list = List(1, 2, 3, 4)
val listBroadcast = sc.broadcast(list)
val results = sc.parallelize(1 to 10).map(x => (x, listBroadcast.value.sum))
assert(results.collect.toSet === (1 to 10).map(x => (x, 10)).toSet)
}
+
+ test("Accessing TorrentBroadcast variables in a local cluster") {
+ System.setProperty("spark.broadcast.factory", "org.apache.spark.broadcast.TorrentBroadcastFactory")
+ val numSlaves = 4
+ sc = new SparkContext("local-cluster[%d, 1, 512]".format(numSlaves), "test")
+ val list = List(1, 2, 3, 4)
+ val listBroadcast = sc.broadcast(list)
+ val results = sc.parallelize(1 to numSlaves).map(x => (x, listBroadcast.value.sum))
+ assert(results.collect.toSet === (1 to numSlaves).map(x => (x, 10)).toSet)
+ }
+
}
diff --git a/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala b/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala
index a192651491..d8a0e983b2 100644
--- a/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala
+++ b/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala
@@ -19,6 +19,8 @@ package org.apache.spark
import java.util.concurrent.Semaphore
+import scala.concurrent.Await
+import scala.concurrent.duration.Duration
import scala.concurrent.future
import scala.concurrent.ExecutionContext.Implicits.global
@@ -83,6 +85,36 @@ class JobCancellationSuite extends FunSuite with ShouldMatchers with BeforeAndAf
assert(sc.parallelize(1 to 10, 2).count === 10)
}
+ test("job group") {
+ sc = new SparkContext("local[2]", "test")
+
+ // Add a listener to release the semaphore once any tasks are launched.
+ val sem = new Semaphore(0)
+ sc.dagScheduler.addSparkListener(new SparkListener {
+ override def onTaskStart(taskStart: SparkListenerTaskStart) {
+ sem.release()
+ }
+ })
+
+ // jobA is the one to be cancelled.
+ val jobA = future {
+ sc.setJobGroup("jobA", "this is a job to be cancelled")
+ sc.parallelize(1 to 10000, 2).map { i => Thread.sleep(10); i }.count()
+ }
+
+ sc.clearJobGroup()
+ val jobB = sc.parallelize(1 to 100, 2).countAsync()
+
+ // Block until both tasks of job A have started and cancel job A.
+ sem.acquire(2)
+ sc.cancelJobGroup("jobA")
+ val e = intercept[SparkException] { Await.result(jobA, Duration.Inf) }
+ assert(e.getMessage contains "cancel")
+
+ // Once A is cancelled, job B should finish fairly quickly.
+ assert(jobB.get() === 100)
+ }
+
test("two jobs sharing the same stage") {
// sem1: make sure cancel is issued after some tasks are launched
// sem2: make sure the first stage is not finished until cancel is issued
diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
index 838179c6b5..2a2f828be6 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
@@ -60,7 +60,7 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with LocalSparkCont
taskSets += taskSet
}
override def cancelTasks(stageId: Int) {}
- override def setListener(listener: TaskSchedulerListener) = {}
+ override def setDAGScheduler(dagScheduler: DAGScheduler) = {}
override def defaultParallelism() = 2
}
diff --git a/core/src/test/scala/org/apache/spark/scheduler/cluster/ClusterTaskSetManagerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/cluster/ClusterTaskSetManagerSuite.scala
index 80d0c5a5e9..b97f2b19b5 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/cluster/ClusterTaskSetManagerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/cluster/ClusterTaskSetManagerSuite.scala
@@ -28,6 +28,30 @@ import org.apache.spark.executor.TaskMetrics
import java.nio.ByteBuffer
import org.apache.spark.util.{Utils, FakeClock}
+class FakeDAGScheduler(taskScheduler: FakeClusterScheduler) extends DAGScheduler(taskScheduler) {
+ override def taskStarted(task: Task[_], taskInfo: TaskInfo) {
+ taskScheduler.startedTasks += taskInfo.index
+ }
+
+ override def taskEnded(
+ task: Task[_],
+ reason: TaskEndReason,
+ result: Any,
+ accumUpdates: mutable.Map[Long, Any],
+ taskInfo: TaskInfo,
+ taskMetrics: TaskMetrics) {
+ taskScheduler.endedTasks(taskInfo.index) = reason
+ }
+
+ override def executorGained(execId: String, host: String) {}
+
+ override def executorLost(execId: String) {}
+
+ override def taskSetFailed(taskSet: TaskSet, reason: String) {
+ taskScheduler.taskSetsFailed += taskSet.id
+ }
+}
+
/**
* A mock ClusterScheduler implementation that just remembers information about tasks started and
* feedback received from the TaskSetManagers. Note that it's important to initialize this with
@@ -44,30 +68,7 @@ class FakeClusterScheduler(sc: SparkContext, liveExecutors: (String, String)* /*
val executors = new mutable.HashMap[String, String] ++ liveExecutors
- listener = new TaskSchedulerListener {
- def taskStarted(task: Task[_], taskInfo: TaskInfo) {
- startedTasks += taskInfo.index
- }
-
- def taskEnded(
- task: Task[_],
- reason: TaskEndReason,
- result: Any,
- accumUpdates: mutable.Map[Long, Any],
- taskInfo: TaskInfo,
- taskMetrics: TaskMetrics)
- {
- endedTasks(taskInfo.index) = reason
- }
-
- def executorGained(execId: String, host: String) {}
-
- def executorLost(execId: String) {}
-
- def taskSetFailed(taskSet: TaskSet, reason: String) {
- taskSetsFailed += taskSet.id
- }
- }
+ dagScheduler = new FakeDAGScheduler(this)
def removeExecutor(execId: String): Unit = executors -= execId
diff --git a/docs/configuration.md b/docs/configuration.md
index 3cf43ab305..97183bafdb 100644
--- a/docs/configuration.md
+++ b/docs/configuration.md
@@ -319,6 +319,14 @@ Apart from these, the following properties are also available, and may be useful
Should be greater than or equal to 1. Number of allowed retries = this value - 1.
</td>
</tr>
+<tr>
+ <td>spark.broadcast.blockSize</td>
+ <td>4096</td>
+ <td>
+ Size of each piece of a block in kilobytes for <code>TorrentBroadcastFactory</code>.
+ Too large a value decreases parallelism during broadcast (makes it slower); however, if it is too small, <code>BlockManager</code> might take a performance hit.
+ </td>
+</tr>
</table>
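
[Editor's note: the piece count implied by spark.broadcast.blockSize follows directly from TorrentBroadcast.blockifyObject earlier in this diff. A quick sketch of the arithmetic; the 10 MB figure is just an example payload size.]

object BroadcastPieceCountSketch {
  def main(args: Array[String]) {
    // spark.broadcast.blockSize is in KB; TorrentBroadcast multiplies it by 1024.
    val blockSize = System.getProperty("spark.broadcast.blockSize", "4096").toInt * 1024
    def numPieces(totalBytes: Int) =
      totalBytes / blockSize + (if (totalBytes % blockSize != 0) 1 else 0)

    // e.g. a ~10 MB serialized broadcast with the default 4096 KB pieces -> 3 pieces
    println(numPieces(10 * 1024 * 1024))
  }
}
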
diff --git a/examples/src/main/scala/org/apache/spark/examples/BroadcastTest.scala b/examples/src/main/scala/org/apache/spark/examples/BroadcastTest.scala
index 868ff81f67..529709c2f9 100644
--- a/examples/src/main/scala/org/apache/spark/examples/BroadcastTest.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/BroadcastTest.scala
@@ -22,12 +22,19 @@ import org.apache.spark.SparkContext
object BroadcastTest {
def main(args: Array[String]) {
if (args.length == 0) {
- System.err.println("Usage: BroadcastTest <master> [<slices>] [numElem]")
+ System.err.println("Usage: BroadcastTest <master> [slices] [numElem] [broadcastAlgo] [blockSize]")
System.exit(1)
}
- val sc = new SparkContext(args(0), "Broadcast Test",
+ val bcName = if (args.length > 3) args(3) else "Http"
+ val blockSize = if (args.length > 4) args(4) else "4096"
+
+ System.setProperty("spark.broadcast.factory", "org.apache.spark.broadcast." + bcName + "BroadcastFactory")
+ System.setProperty("spark.broadcast.blockSize", blockSize)
+
+ val sc = new SparkContext(args(0), "Broadcast Test 2",
System.getenv("SPARK_HOME"), Seq(System.getenv("SPARK_EXAMPLES_JAR")))
+
val slices = if (args.length > 1) args(1).toInt else 2
val num = if (args.length > 2) args(2).toInt else 1000000
@@ -36,13 +43,15 @@ object BroadcastTest {
arr1(i) = i
}
- for (i <- 0 until 2) {
+ for (i <- 0 until 3) {
println("Iteration " + i)
println("===========")
+ val startTime = System.nanoTime
val barr1 = sc.broadcast(arr1)
sc.parallelize(1 to 10, slices).foreach {
i => println(barr1.value.size)
}
+ println("Iteration %d took %.0f milliseconds".format(i, (System.nanoTime - startTime) / 1E6))
}
System.exit(0)
diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala
index f2bbe5358f..965c4f3a63 100644
--- a/project/SparkBuild.scala
+++ b/project/SparkBuild.scala
@@ -60,6 +60,8 @@ object SparkBuild extends Build {
lazy val assemblyProj = Project("assembly", file("assembly"), settings = assemblyProjSettings)
.dependsOn(core, bagel, mllib, repl, streaming) dependsOn(maybeYarn: _*)
+ lazy val assembleDeps = TaskKey[Unit]("assemble-deps", "Build assembly of dependencies and packages Spark projects")
+
// A configuration to set an alternative publishLocalConfiguration
lazy val MavenCompile = config("m2r") extend(Compile)
lazy val publishLocalBoth = TaskKey[Unit]("publish-local", "publish local for m2 and ivy")
@@ -74,8 +76,11 @@ object SparkBuild extends Build {
// Conditionally include the yarn sub-project
lazy val maybeYarn = if(isYarnEnabled) Seq[ClasspathDependency](yarn) else Seq[ClasspathDependency]()
lazy val maybeYarnRef = if(isYarnEnabled) Seq[ProjectReference](yarn) else Seq[ProjectReference]()
- lazy val allProjects = Seq[ProjectReference](
- core, repl, examples, bagel, streaming, mllib, tools, assemblyProj) ++ maybeYarnRef
+
+ // Everything except assembly, tools and examples belong to packageProjects
+ lazy val packageProjects = Seq[ProjectReference](core, repl, bagel, streaming, mllib) ++ maybeYarnRef
+
+ lazy val allProjects = packageProjects ++ Seq[ProjectReference](examples, tools, assemblyProj)
def sharedSettings = Defaults.defaultSettings ++ Seq(
organization := "org.apache.spark",
@@ -303,7 +308,9 @@ object SparkBuild extends Build {
def assemblyProjSettings = sharedSettings ++ Seq(
name := "spark-assembly",
- jarName in assembly <<= version map { v => "spark-assembly-" + v + "-hadoop" + hadoopVersion + ".jar" }
+ assembleDeps in Compile <<= (packageProjects.map(packageBin in Compile in _) ++ Seq(packageDependency in Compile)).dependOn,
+ jarName in assembly <<= version map { v => "spark-assembly-" + v + "-hadoop" + hadoopVersion + ".jar" },
+ jarName in packageDependency <<= version map { v => "spark-assembly-" + v + "-hadoop" + hadoopVersion + "-deps.jar" }
) ++ assemblySettings ++ extraAssemblySettings
def extraAssemblySettings() = Seq(
diff --git a/repl/src/main/scala/org/apache/spark/repl/SparkILoop.scala b/repl/src/main/scala/org/apache/spark/repl/SparkILoop.scala
index 36f54a22cf..48a8fa9328 100644
--- a/repl/src/main/scala/org/apache/spark/repl/SparkILoop.scala
+++ b/repl/src/main/scala/org/apache/spark/repl/SparkILoop.scala
@@ -845,7 +845,14 @@ class SparkILoop(in0: Option[BufferedReader], val out: PrintWriter, val master:
val jars = Option(System.getenv("ADD_JARS")).map(_.split(','))
.getOrElse(new Array[String](0))
.map(new java.io.File(_).getAbsolutePath)
- sparkContext = new SparkContext(master, "Spark shell", System.getenv("SPARK_HOME"), jars)
+ try {
+ sparkContext = new SparkContext(master, "Spark shell", System.getenv("SPARK_HOME"), jars)
+ } catch {
+ case e: Exception =>
+ e.printStackTrace()
+ echo("Failed to create SparkContext, exiting...")
+ sys.exit(1)
+ }
sparkContext
}
diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocationHandler.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocationHandler.scala
index 6d6ef149cc..25da9aa917 100644
--- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocationHandler.scala
+++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocationHandler.scala
@@ -22,7 +22,7 @@ import org.apache.spark.util.Utils
import org.apache.spark.scheduler.SplitInfo
import scala.collection
import org.apache.hadoop.yarn.api.records.{AMResponse, ApplicationAttemptId, ContainerId, Priority, Resource, ResourceRequest, ContainerStatus, Container}
-import org.apache.spark.scheduler.cluster.{ClusterScheduler, StandaloneSchedulerBackend}
+import org.apache.spark.scheduler.cluster.{ClusterScheduler, CoarseGrainedSchedulerBackend}
import org.apache.hadoop.yarn.api.protocolrecords.{AllocateRequest, AllocateResponse}
import org.apache.hadoop.yarn.util.{RackResolver, Records}
import java.util.concurrent.{CopyOnWriteArrayList, ConcurrentHashMap}
@@ -211,7 +211,7 @@ private[yarn] class YarnAllocationHandler(val conf: Configuration, val resourceM
val workerId = workerIdCounter.incrementAndGet().toString
val driverUrl = "akka://spark@%s:%s/user/%s".format(
System.getProperty("spark.driver.host"), System.getProperty("spark.driver.port"),
- StandaloneSchedulerBackend.ACTOR_NAME)
+ CoarseGrainedSchedulerBackend.ACTOR_NAME)
logInfo("launching container on " + containerId + " host " + workerHostname)
// just to be safe, simply remove it from pendingReleaseContainers. Should not be there, but ..