 core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala           | 66
 core/src/main/scala/org/apache/spark/executor/Executor.scala                               |  2
 core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala   |  6
 core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala | 15
 core/src/main/scala/org/apache/spark/util/AkkaUtils.scala                                  |  3
 core/src/test/scala/org/apache/spark/scheduler/CoarseGrainedSchedulerBackendSuite.scala    |  2
 6 files changed, 54 insertions(+), 40 deletions(-)
diff --git a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala
index 2279d77c91..b5fd334f40 100644
--- a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala
+++ b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala
@@ -19,25 +19,26 @@ package org.apache.spark.executor
import java.nio.ByteBuffer
-import akka.actor._
-import akka.remote._
+import scala.concurrent.Await
-import org.apache.spark.{SparkEnv, Logging, SecurityManager, SparkConf}
+import akka.actor.{Actor, ActorSelection, Props}
+import akka.pattern.Patterns
+import akka.remote.{RemotingLifecycleEvent, DisassociatedEvent}
+
+import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkEnv}
import org.apache.spark.TaskState.TaskState
import org.apache.spark.deploy.SparkHadoopUtil
import org.apache.spark.deploy.worker.WorkerWatcher
-import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages._
import org.apache.spark.scheduler.TaskDescription
+import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages._
import org.apache.spark.util.{AkkaUtils, Utils}
private[spark] class CoarseGrainedExecutorBackend(
driverUrl: String,
executorId: String,
hostPort: String,
- cores: Int)
- extends Actor
- with ExecutorBackend
- with Logging {
+ cores: Int,
+ sparkProperties: Seq[(String, String)]) extends Actor with ExecutorBackend with Logging {
Utils.checkHostPort(hostPort, "Expected hostport")
@@ -52,7 +53,7 @@ private[spark] class CoarseGrainedExecutorBackend(
}
override def receive = {
- case RegisteredExecutor(sparkProperties) =>
+ case RegisteredExecutor =>
logInfo("Successfully registered with driver")
// Make this host instead of hostPort ?
executor = new Executor(executorId, Utils.parseHostPort(hostPort)._1, sparkProperties,
@@ -101,26 +102,33 @@ private[spark] object CoarseGrainedExecutorBackend {
workerUrl: Option[String]) {
SparkHadoopUtil.get.runAsSparkUser { () =>
- // Debug code
- Utils.checkHost(hostname)
-
- val conf = new SparkConf
- // Create a new ActorSystem to run the backend, because we can't create a
- // SparkEnv / Executor before getting started with all our system properties, etc
- val (actorSystem, boundPort) = AkkaUtils.createActorSystem("sparkExecutor", hostname, 0,
- conf, new SecurityManager(conf))
- // set it
- val sparkHostPort = hostname + ":" + boundPort
- actorSystem.actorOf(
- Props(classOf[CoarseGrainedExecutorBackend], driverUrl, executorId,
- sparkHostPort, cores),
- name = "Executor")
- workerUrl.foreach {
- url =>
- actorSystem.actorOf(Props(classOf[WorkerWatcher], url), name = "WorkerWatcher")
- }
- actorSystem.awaitTermination()
-
+ // Debug code
+ Utils.checkHost(hostname)
+
+ // Bootstrap to fetch the driver's Spark properties.
+ val executorConf = new SparkConf
+ val (fetcher, _) = AkkaUtils.createActorSystem(
+ "driverPropsFetcher", hostname, 0, executorConf, new SecurityManager(executorConf))
+ val driver = fetcher.actorSelection(driverUrl)
+ val timeout = AkkaUtils.askTimeout(executorConf)
+ val fut = Patterns.ask(driver, RetrieveSparkProps, timeout)
+ val props = Await.result(fut, timeout).asInstanceOf[Seq[(String, String)]]
+ fetcher.shutdown()
+
+ // Create a new ActorSystem using driver's Spark properties to run the backend.
+ val driverConf = new SparkConf().setAll(props)
+ val (actorSystem, boundPort) = AkkaUtils.createActorSystem(
+ "sparkExecutor", hostname, 0, driverConf, new SecurityManager(driverConf))
+ // set it
+ val sparkHostPort = hostname + ":" + boundPort
+ actorSystem.actorOf(
+ Props(classOf[CoarseGrainedExecutorBackend],
+ driverUrl, executorId, sparkHostPort, cores, props),
+ name = "Executor")
+ workerUrl.foreach { url =>
+ actorSystem.actorOf(Props(classOf[WorkerWatcher], url), name = "WorkerWatcher")
+ }
+ actorSystem.awaitTermination()
}
}
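
The bootstrap added above is a blocking ask/await round-trip on a throwaway actor system: the executor reaches the driver once to fetch its Spark properties, tears the temporary system down, then builds the real one from the fetched configuration. A minimal standalone sketch of the same pattern, assuming plain Akka 2.x APIs rather than Spark's AkkaUtils helpers; the object name, method name, and timeout value are illustrative, not part of the patch:

    import scala.concurrent.Await
    import scala.concurrent.duration._

    import akka.actor.ActorSystem
    import akka.pattern.Patterns

    object PropsFetchSketch {
      case object RetrieveSparkProps  // stand-in for the real cluster message

      def fetchDriverProps(driverUrl: String): Seq[(String, String)] = {
        // Throwaway actor system, used only to reach the driver before the
        // executor's real configuration is known.
        val fetcher = ActorSystem("driverPropsFetcher")
        try {
          val driver = fetcher.actorSelection(driverUrl)
          val timeout = 30.seconds
          val future = Patterns.ask(driver, RetrieveSparkProps, timeout.toMillis)
          Await.result(future, timeout).asInstanceOf[Seq[(String, String)]]
        } finally {
          fetcher.shutdown()  // always tear down the bootstrap system
        }
      }
    }
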
diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala
index baee7a216a..557b9a3f46 100644
--- a/core/src/main/scala/org/apache/spark/executor/Executor.scala
+++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala
@@ -212,7 +212,7 @@ private[spark] class Executor(
val serializedDirectResult = ser.serialize(directResult)
logInfo("Serialized size of result for " + taskId + " is " + serializedDirectResult.limit)
val serializedResult = {
- if (serializedDirectResult.limit >= akkaFrameSize - 1024) {
+ if (serializedDirectResult.limit >= akkaFrameSize - AkkaUtils.reservedSizeBytes) {
logInfo("Storing result for " + taskId + " in local BlockManager")
val blockId = TaskResultBlockId(taskId)
env.blockManager.putBytes(
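
The executor-side check now reserves a fixed 200 KB of headroom in each Akka frame instead of the old hard-coded 1024 bytes. A small sketch of the decision under illustrative names; `putInBlockManager` stands in for the real BlockManager round-trip that Spark performs:

    import java.nio.ByteBuffer

    object ResultSizeGuard {
      val reservedSizeBytes = 200 * 1024  // mirrors AkkaUtils.reservedSizeBytes

      // Returns the buffer to send back: the direct result if it fits in the
      // frame with headroom to spare, otherwise whatever the BlockManager
      // round-trip produces (an indirect reference in Spark itself).
      def chooseResultPath(direct: ByteBuffer, akkaFrameSize: Int)
                          (putInBlockManager: ByteBuffer => ByteBuffer): ByteBuffer = {
        if (direct.limit >= akkaFrameSize - reservedSizeBytes) {
          putInBlockManager(direct)
        } else {
          direct
        }
      }
    }
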
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala
index ca74069ef8..318e165522 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala
@@ -20,21 +20,21 @@ package org.apache.spark.scheduler.cluster
import java.nio.ByteBuffer
import org.apache.spark.TaskState.TaskState
-import org.apache.spark.scheduler.TaskDescription
import org.apache.spark.util.{SerializableBuffer, Utils}
private[spark] sealed trait CoarseGrainedClusterMessage extends Serializable
private[spark] object CoarseGrainedClusterMessages {
+ case object RetrieveSparkProps extends CoarseGrainedClusterMessage
+
// Driver to executors
case class LaunchTask(data: SerializableBuffer) extends CoarseGrainedClusterMessage
case class KillTask(taskId: Long, executor: String, interruptThread: Boolean)
extends CoarseGrainedClusterMessage
- case class RegisteredExecutor(sparkProperties: Seq[(String, String)])
- extends CoarseGrainedClusterMessage
+ case object RegisteredExecutor extends CoarseGrainedClusterMessage
case class RegisterExecutorFailed(message: String) extends CoarseGrainedClusterMessage
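
With the properties now fetched up front via RetrieveSparkProps, the registration ack no longer needs a payload, so RegisteredExecutor shrinks from a case class to a case object. A self-contained sketch of the new handshake, assuming plain Akka actors; the driver stub and its property list are illustrative:

    import scala.concurrent.Await
    import scala.concurrent.duration._

    import akka.actor.{Actor, ActorSystem, Props}
    import akka.pattern.Patterns

    // Stand-ins for the real cluster messages above.
    case object RetrieveSparkProps
    case object RegisteredExecutor

    // Driver-side stub: properties travel in the RetrieveSparkProps reply,
    // so the registration ack carries nothing.
    class DriverStub(sparkProperties: Seq[(String, String)]) extends Actor {
      def receive = {
        case RetrieveSparkProps => sender ! sparkProperties
        case "register"         => sender ! RegisteredExecutor
      }
    }

    object ProtocolSketch extends App {
      val system = ActorSystem("protocolSketch")
      val driver = system.actorOf(
        Props(classOf[DriverStub], Seq("spark.app.name" -> "sketch")), name = "driver")

      val reply = Await.result(Patterns.ask(driver, RetrieveSparkProps, 5000), 5.seconds)
      println(reply)  // List((spark.app.name,sketch))
      system.shutdown()
    }
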
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala
index e47a060683..05d01b0c82 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala
@@ -75,7 +75,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, actorSystem: A
sender ! RegisterExecutorFailed("Duplicate executor ID: " + executorId)
} else {
logInfo("Registered executor: " + sender + " with ID " + executorId)
- sender ! RegisteredExecutor(sparkProperties)
+ sender ! RegisteredExecutor
executorActor(executorId) = sender
executorHost(executorId) = Utils.parseHostPort(hostPort)._1
totalCores(executorId) = cores
@@ -124,6 +124,8 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, actorSystem: A
addressToExecutorId.get(address).foreach(removeExecutor(_,
"remote Akka client disassociated"))
+ case RetrieveSparkProps =>
+ sender ! sparkProperties
}
// Make fake resource offers on all executors
@@ -143,14 +145,15 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, actorSystem: A
for (task <- tasks.flatten) {
val ser = SparkEnv.get.closureSerializer.newInstance()
val serializedTask = ser.serialize(task)
- if (serializedTask.limit >= akkaFrameSize - 1024) {
+ if (serializedTask.limit >= akkaFrameSize - AkkaUtils.reservedSizeBytes) {
val taskSetId = scheduler.taskIdToTaskSetId(task.taskId)
scheduler.activeTaskSets.get(taskSetId).foreach { taskSet =>
try {
- var msg = "Serialized task %s:%d was %d bytes which " +
- "exceeds spark.akka.frameSize (%d bytes). " +
- "Consider using broadcast variables for large values."
- msg = msg.format(task.taskId, task.index, serializedTask.limit, akkaFrameSize)
+ var msg = "Serialized task %s:%d was %d bytes, which exceeds max allowed: " +
+ "spark.akka.frameSize (%d bytes) - reserved (%d bytes). Consider increasing " +
+ "spark.akka.frameSize or using broadcast variables for large values."
+ msg = msg.format(task.taskId, task.index, serializedTask.limit, akkaFrameSize,
+ AkkaUtils.reservedSizeBytes)
taskSet.abort(msg)
} catch {
case e: Exception => logError("Exception in error callback", e)
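
The driver-side guard mirrors the executor's: a task is rejected when its serialized size leaves less than the reserved headroom in the frame. A runnable illustration with made-up sizes; `abort` stands in for the TaskSetManager abort that Spark performs:

    object TaskSizeCheck extends App {
      val akkaFrameSize = 10 * 1024 * 1024        // spark.akka.frameSize default, in bytes
      val reservedSizeBytes = 200 * 1024          // mirrors AkkaUtils.reservedSizeBytes
      val serializedTaskLimit = 11 * 1024 * 1024  // pretend serialized task size

      def abort(message: String): Unit = println(s"Aborting task set: $message")

      if (serializedTaskLimit >= akkaFrameSize - reservedSizeBytes) {
        val msg = ("Serialized task %s:%d was %d bytes, which exceeds max allowed: " +
          "spark.akka.frameSize (%d bytes) - reserved (%d bytes). Consider increasing " +
          "spark.akka.frameSize or using broadcast variables for large values.")
          .format("0", 0, serializedTaskLimit, akkaFrameSize, reservedSizeBytes)
        abort(msg)
      }
    }
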
diff --git a/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala b/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala
index a8d12bb2a0..9930c71749 100644
--- a/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala
+++ b/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala
@@ -121,4 +121,7 @@ private[spark] object AkkaUtils extends Logging {
def maxFrameSizeBytes(conf: SparkConf): Int = {
conf.getInt("spark.akka.frameSize", 10) * 1024 * 1024
}
+
+ /** Space reserved for extra data in an Akka message besides serialized task or task result. */
+ val reservedSizeBytes = 200 * 1024
}
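
With the default 10 MB frame, the effective ceiling for a directly shipped task or result becomes maxFrameSizeBytes minus reservedSizeBytes. A quick arithmetic check of that budget:

    object FrameBudget extends App {
      // Mirrors AkkaUtils.maxFrameSizeBytes with the default of 10 MB.
      def maxFrameSizeBytes(frameSizeMb: Int): Int = frameSizeMb * 1024 * 1024
      val reservedSizeBytes = 200 * 1024  // mirrors AkkaUtils.reservedSizeBytes

      // 10485760 - 204800 = 10280960 bytes of usable payload per message.
      println(maxFrameSizeBytes(10) - reservedSizeBytes)
    }
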
diff --git a/core/src/test/scala/org/apache/spark/scheduler/CoarseGrainedSchedulerBackendSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/CoarseGrainedSchedulerBackendSuite.scala
index efef9d26da..f77661ccbd 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/CoarseGrainedSchedulerBackendSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/CoarseGrainedSchedulerBackendSuite.scala
@@ -35,7 +35,7 @@ class CoarseGrainedSchedulerBackendSuite extends FunSuite with LocalSparkContext
val thrown = intercept[SparkException] {
larger.collect()
}
- assert(thrown.getMessage.contains("Consider using broadcast variables for large values"))
+ assert(thrown.getMessage.contains("using broadcast variables for large values"))
val smaller = sc.parallelize(1 to 4).collect()
assert(smaller.size === 4)
}