diff options
author | Prashant Sharma <prashant.s@imaginea.com> | 2013-07-03 11:43:26 +0530 |
---|---|---|
committer | Prashant Sharma <prashant.s@imaginea.com> | 2013-07-03 11:43:26 +0530 |
commit | a5f1f6a907b116325c56d38157ec2df76150951e (patch) | |
tree | 27de949c24a61b2301c7690db9e28992f49ea39c /core | |
parent | b7794813b181f13801596e8d8c3b4471c0c84f20 (diff) | |
parent | 6d60fe571a405eb9306a2be1817901316a46f892 (diff) | |
download | spark-a5f1f6a907b116325c56d38157ec2df76150951e.tar.gz spark-a5f1f6a907b116325c56d38157ec2df76150951e.tar.bz2 spark-a5f1f6a907b116325c56d38157ec2df76150951e.zip |
Merge branch 'master' into master-merge
Conflicts:
core/pom.xml
core/src/main/scala/spark/MapOutputTracker.scala
core/src/main/scala/spark/RDD.scala
core/src/main/scala/spark/RDDCheckpointData.scala
core/src/main/scala/spark/SparkContext.scala
core/src/main/scala/spark/Utils.scala
core/src/main/scala/spark/api/python/PythonRDD.scala
core/src/main/scala/spark/deploy/client/Client.scala
core/src/main/scala/spark/deploy/master/MasterWebUI.scala
core/src/main/scala/spark/deploy/worker/Worker.scala
core/src/main/scala/spark/deploy/worker/WorkerWebUI.scala
core/src/main/scala/spark/rdd/BlockRDD.scala
core/src/main/scala/spark/rdd/ZippedRDD.scala
core/src/main/scala/spark/scheduler/cluster/StandaloneSchedulerBackend.scala
core/src/main/scala/spark/storage/BlockManager.scala
core/src/main/scala/spark/storage/BlockManagerMaster.scala
core/src/main/scala/spark/storage/BlockManagerMasterActor.scala
core/src/main/scala/spark/storage/BlockManagerUI.scala
core/src/main/scala/spark/util/AkkaUtils.scala
core/src/test/scala/spark/SizeEstimatorSuite.scala
pom.xml
project/SparkBuild.scala
repl/src/main/scala/spark/repl/SparkILoop.scala
repl/src/test/scala/spark/repl/ReplSuite.scala
streaming/src/main/scala/spark/streaming/StreamingContext.scala
streaming/src/main/scala/spark/streaming/api/java/JavaStreamingContext.scala
streaming/src/main/scala/spark/streaming/dstream/KafkaInputDStream.scala
streaming/src/main/scala/spark/streaming/util/MasterFailureTest.scala
Diffstat (limited to 'core')
174 files changed, 9871 insertions, 2368 deletions
diff --git a/core/pom.xml b/core/pom.xml index 7f5cffc818..385663a638 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -32,8 +32,8 @@ <artifactId>compress-lzf</artifactId> </dependency> <dependency> - <groupId>asm</groupId> - <artifactId>asm-all</artifactId> + <groupId>org.ow2.asm</groupId> + <artifactId>asm</artifactId> </dependency> <dependency> <groupId>com.google.protobuf</groupId> @@ -85,18 +85,27 @@ </dependency> <dependency> <groupId>com.github.scala-incubator.io</groupId> - <artifactId>scala-io-file_${scala.version}</artifactId> + <artifactId>scala-io-file_2.10</artifactId> </dependency> <dependency> <groupId>org.apache.mesos</groupId> <artifactId>mesos</artifactId> </dependency> <dependency> + <groupId>io.netty</groupId> + <artifactId>netty-all</artifactId> + </dependency> + <dependency> <groupId>log4j</groupId> <artifactId>log4j</artifactId> </dependency> <dependency> + <groupId>org.apache.derby</groupId> + <artifactId>derby</artifactId> + <scope>test</scope> + </dependency> + <dependency> <groupId>org.scalatest</groupId> <artifactId>scalatest_${scala.version}</artifactId> <scope>test</scope> @@ -283,5 +292,72 @@ </plugins> </build> </profile> + <profile> + <id>hadoop2-yarn</id> + <dependencies> + <dependency> + <groupId>org.apache.hadoop</groupId> + <artifactId>hadoop-client</artifactId> + <scope>provided</scope> + </dependency> + <dependency> + <groupId>org.apache.hadoop</groupId> + <artifactId>hadoop-yarn-api</artifactId> + <scope>provided</scope> + </dependency> + <dependency> + <groupId>org.apache.hadoop</groupId> + <artifactId>hadoop-yarn-common</artifactId> + <scope>provided</scope> + </dependency> + <dependency> + <groupId>org.apache.hadoop</groupId> + <artifactId>hadoop-yarn-client</artifactId> + <scope>provided</scope> + </dependency> + </dependencies> + <build> + <plugins> + <plugin> + <groupId>org.codehaus.mojo</groupId> + <artifactId>build-helper-maven-plugin</artifactId> + <executions> + <execution> + <id>add-source</id> + 
<phase>generate-sources</phase> + <goals> + <goal>add-source</goal> + </goals> + <configuration> + <sources> + <source>src/main/scala</source> + <source>src/hadoop2-yarn/scala</source> + </sources> + </configuration> + </execution> + <execution> + <id>add-scala-test-sources</id> + <phase>generate-test-sources</phase> + <goals> + <goal>add-test-source</goal> + </goals> + <configuration> + <sources> + <source>src/test/scala</source> + </sources> + </configuration> + </execution> + </executions> + </plugin> + <plugin> + <groupId>org.apache.maven.plugins</groupId> + <artifactId>maven-jar-plugin</artifactId> + <configuration> + <classifier>hadoop2-yarn</classifier> + </configuration> + </plugin> + </plugins> + </build> + </profile> </profiles> </project> diff --git a/core/src/hadoop1/scala/org/apache/hadoop/mapred/HadoopMapRedUtil.scala b/core/src/hadoop1/scala/org/apache/hadoop/mapred/HadoopMapRedUtil.scala index ca9f7219de..f286f2cf9c 100644 --- a/core/src/hadoop1/scala/org/apache/hadoop/mapred/HadoopMapRedUtil.scala +++ b/core/src/hadoop1/scala/org/apache/hadoop/mapred/HadoopMapRedUtil.scala @@ -4,4 +4,7 @@ trait HadoopMapRedUtil { def newJobContext(conf: JobConf, jobId: JobID): JobContext = new JobContext(conf, jobId) def newTaskAttemptContext(conf: JobConf, attemptId: TaskAttemptID): TaskAttemptContext = new TaskAttemptContext(conf, attemptId) + + def newTaskAttemptID(jtIdentifier: String, jobId: Int, isMap: Boolean, taskId: Int, attemptId: Int) = new TaskAttemptID(jtIdentifier, + jobId, isMap, taskId, attemptId) } diff --git a/core/src/hadoop1/scala/org/apache/hadoop/mapreduce/HadoopMapReduceUtil.scala b/core/src/hadoop1/scala/org/apache/hadoop/mapreduce/HadoopMapReduceUtil.scala index de7b0f81e3..264d421d14 100644 --- a/core/src/hadoop1/scala/org/apache/hadoop/mapreduce/HadoopMapReduceUtil.scala +++ b/core/src/hadoop1/scala/org/apache/hadoop/mapreduce/HadoopMapReduceUtil.scala @@ -6,4 +6,7 @@ trait HadoopMapReduceUtil { def newJobContext(conf: Configuration, 
jobId: JobID): JobContext = new JobContext(conf, jobId) def newTaskAttemptContext(conf: Configuration, attemptId: TaskAttemptID): TaskAttemptContext = new TaskAttemptContext(conf, attemptId) + + def newTaskAttemptID(jtIdentifier: String, jobId: Int, isMap: Boolean, taskId: Int, attemptId: Int) = new TaskAttemptID(jtIdentifier, + jobId, isMap, taskId, attemptId) } diff --git a/core/src/hadoop1/scala/spark/deploy/SparkHadoopUtil.scala b/core/src/hadoop1/scala/spark/deploy/SparkHadoopUtil.scala new file mode 100644 index 0000000000..a0fb4fe25d --- /dev/null +++ b/core/src/hadoop1/scala/spark/deploy/SparkHadoopUtil.scala @@ -0,0 +1,23 @@ +package spark.deploy +import org.apache.hadoop.conf.Configuration + + +/** + * Contains util methods to interact with Hadoop from spark. + */ +object SparkHadoopUtil { + + def getUserNameFromEnvironment(): String = { + // defaulting to -D ... + System.getProperty("user.name") + } + + def runAsUser(func: (Product) => Unit, args: Product) { + + // Add support, if exists - for now, simply run func ! + func(args) + } + + // Return an appropriate (subclass) of Configuration. 
Creating config can initializes some hadoop subsystems + def newConfiguration(): Configuration = new Configuration() +} diff --git a/core/src/hadoop2-yarn/scala/org/apache/hadoop/mapred/HadoopMapRedUtil.scala b/core/src/hadoop2-yarn/scala/org/apache/hadoop/mapred/HadoopMapRedUtil.scala new file mode 100644 index 0000000000..875c0a220b --- /dev/null +++ b/core/src/hadoop2-yarn/scala/org/apache/hadoop/mapred/HadoopMapRedUtil.scala @@ -0,0 +1,13 @@ + +package org.apache.hadoop.mapred + +import org.apache.hadoop.mapreduce.TaskType + +trait HadoopMapRedUtil { + def newJobContext(conf: JobConf, jobId: JobID): JobContext = new JobContextImpl(conf, jobId) + + def newTaskAttemptContext(conf: JobConf, attemptId: TaskAttemptID): TaskAttemptContext = new TaskAttemptContextImpl(conf, attemptId) + + def newTaskAttemptID(jtIdentifier: String, jobId: Int, isMap: Boolean, taskId: Int, attemptId: Int) = + new TaskAttemptID(jtIdentifier, jobId, if (isMap) TaskType.MAP else TaskType.REDUCE, taskId, attemptId) +} diff --git a/core/src/hadoop2-yarn/scala/org/apache/hadoop/mapreduce/HadoopMapReduceUtil.scala b/core/src/hadoop2-yarn/scala/org/apache/hadoop/mapreduce/HadoopMapReduceUtil.scala new file mode 100644 index 0000000000..8bc6fb6dea --- /dev/null +++ b/core/src/hadoop2-yarn/scala/org/apache/hadoop/mapreduce/HadoopMapReduceUtil.scala @@ -0,0 +1,13 @@ +package org.apache.hadoop.mapreduce + +import org.apache.hadoop.conf.Configuration +import task.{TaskAttemptContextImpl, JobContextImpl} + +trait HadoopMapReduceUtil { + def newJobContext(conf: Configuration, jobId: JobID): JobContext = new JobContextImpl(conf, jobId) + + def newTaskAttemptContext(conf: Configuration, attemptId: TaskAttemptID): TaskAttemptContext = new TaskAttemptContextImpl(conf, attemptId) + + def newTaskAttemptID(jtIdentifier: String, jobId: Int, isMap: Boolean, taskId: Int, attemptId: Int) = + new TaskAttemptID(jtIdentifier, jobId, if (isMap) TaskType.MAP else TaskType.REDUCE, taskId, attemptId) +} diff --git 
a/core/src/hadoop2-yarn/scala/spark/deploy/SparkHadoopUtil.scala b/core/src/hadoop2-yarn/scala/spark/deploy/SparkHadoopUtil.scala new file mode 100644 index 0000000000..ab1ab9d8a7 --- /dev/null +++ b/core/src/hadoop2-yarn/scala/spark/deploy/SparkHadoopUtil.scala @@ -0,0 +1,63 @@ +package spark.deploy + +import collection.mutable.HashMap +import org.apache.hadoop.security.UserGroupInformation +import org.apache.hadoop.yarn.conf.YarnConfiguration +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.yarn.api.ApplicationConstants.Environment +import java.security.PrivilegedExceptionAction + +/** + * Contains util methods to interact with Hadoop from spark. + */ +object SparkHadoopUtil { + + val yarnConf = newConfiguration() + + def getUserNameFromEnvironment(): String = { + // defaulting to env if -D is not present ... + val retval = System.getProperty(Environment.USER.name, System.getenv(Environment.USER.name)) + + // If nothing found, default to user we are running as + if (retval == null) System.getProperty("user.name") else retval + } + + def runAsUser(func: (Product) => Unit, args: Product) { + runAsUser(func, args, getUserNameFromEnvironment()) + } + + def runAsUser(func: (Product) => Unit, args: Product, user: String) { + + // println("running as user " + jobUserName) + + UserGroupInformation.setConfiguration(yarnConf) + val appMasterUgi: UserGroupInformation = UserGroupInformation.createRemoteUser(user) + appMasterUgi.doAs(new PrivilegedExceptionAction[AnyRef] { + def run: AnyRef = { + func(args) + // no return value ... + null + } + }) + } + + // Note that all params which start with SPARK are propagated all the way through, so if in yarn mode, this MUST be set to true. + def isYarnMode(): Boolean = { + val yarnMode = System.getProperty("SPARK_YARN_MODE", System.getenv("SPARK_YARN_MODE")) + java.lang.Boolean.valueOf(yarnMode) + } + + // Set an env variable indicating we are running in YARN mode. 
+ // Note that anything with SPARK prefix gets propagated to all (remote) processes + def setYarnMode() { + System.setProperty("SPARK_YARN_MODE", "true") + } + + def setYarnMode(env: HashMap[String, String]) { + env("SPARK_YARN_MODE") = "true" + } + + // Return an appropriate (subclass) of Configuration. Creating config can initializes some hadoop subsystems + // Always create a new config, dont reuse yarnConf. + def newConfiguration(): Configuration = new YarnConfiguration(new Configuration()) +} diff --git a/core/src/hadoop2-yarn/scala/spark/deploy/yarn/ApplicationMaster.scala b/core/src/hadoop2-yarn/scala/spark/deploy/yarn/ApplicationMaster.scala new file mode 100644 index 0000000000..aa72c1e5fe --- /dev/null +++ b/core/src/hadoop2-yarn/scala/spark/deploy/yarn/ApplicationMaster.scala @@ -0,0 +1,329 @@ +package spark.deploy.yarn + +import java.net.Socket +import java.util.concurrent.CopyOnWriteArrayList +import java.util.concurrent.atomic.{AtomicInteger, AtomicReference} +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.net.NetUtils +import org.apache.hadoop.yarn.api._ +import org.apache.hadoop.yarn.api.records._ +import org.apache.hadoop.yarn.api.protocolrecords._ +import org.apache.hadoop.yarn.conf.YarnConfiguration +import org.apache.hadoop.yarn.ipc.YarnRPC +import org.apache.hadoop.yarn.util.{ConverterUtils, Records} +import scala.collection.JavaConversions._ +import spark.{SparkContext, Logging, Utils} +import org.apache.hadoop.security.UserGroupInformation +import java.security.PrivilegedExceptionAction + +class ApplicationMaster(args: ApplicationMasterArguments, conf: Configuration) extends Logging { + + def this(args: ApplicationMasterArguments) = this(args, new Configuration()) + + private var rpc: YarnRPC = YarnRPC.create(conf) + private var resourceManager: AMRMProtocol = null + private var appAttemptId: ApplicationAttemptId = null + private var userThread: Thread = null + private val yarnConf: YarnConfiguration = new 
YarnConfiguration(conf) + + private var yarnAllocator: YarnAllocationHandler = null + + def run() { + + // Initialization + val jobUserName = Utils.getUserNameFromEnvironment() + logInfo("running as user " + jobUserName) + + // run as user ... + UserGroupInformation.setConfiguration(yarnConf) + val appMasterUgi: UserGroupInformation = UserGroupInformation.createRemoteUser(jobUserName) + appMasterUgi.doAs(new PrivilegedExceptionAction[AnyRef] { + def run: AnyRef = { + runImpl() + return null + } + }) + } + + private def runImpl() { + + appAttemptId = getApplicationAttemptId() + resourceManager = registerWithResourceManager() + val appMasterResponse: RegisterApplicationMasterResponse = registerApplicationMaster() + + // Compute number of threads for akka + val minimumMemory = appMasterResponse.getMinimumResourceCapability().getMemory() + + if (minimumMemory > 0) { + val mem = args.workerMemory + YarnAllocationHandler.MEMORY_OVERHEAD + val numCore = (mem / minimumMemory) + (if (0 != (mem % minimumMemory)) 1 else 0) + + if (numCore > 0) { + // do not override - hits https://issues.apache.org/jira/browse/HADOOP-8406 + // TODO: Uncomment when hadoop is on a version which has this fixed. + // args.workerCores = numCore + } + } + + // Workaround until hadoop moves to something which has + // https://issues.apache.org/jira/browse/HADOOP-8406 + // ignore result + // This does not, unfortunately, always work reliably ... but alleviates the bug a lot of times + // Hence args.workerCores = numCore disabled above. Any better option ? + // org.apache.hadoop.io.compress.CompressionCodecFactory.getCodecClasses(conf) + + ApplicationMaster.register(this) + // Start the user's JAR + userThread = startUserClass() + + // This a bit hacky, but we need to wait until the spark.driver.port property has + // been set by the Thread executing the user class. 
+ waitForSparkMaster() + + // Allocate all containers + allocateWorkers() + + // Wait for the user class to Finish + userThread.join() + + // Finish the ApplicationMaster + finishApplicationMaster() + // TODO: Exit based on success/failure + System.exit(0) + } + + private def getApplicationAttemptId(): ApplicationAttemptId = { + val envs = System.getenv() + val containerIdString = envs.get(ApplicationConstants.AM_CONTAINER_ID_ENV) + val containerId = ConverterUtils.toContainerId(containerIdString) + val appAttemptId = containerId.getApplicationAttemptId() + logInfo("ApplicationAttemptId: " + appAttemptId) + return appAttemptId + } + + private def registerWithResourceManager(): AMRMProtocol = { + val rmAddress = NetUtils.createSocketAddr(yarnConf.get( + YarnConfiguration.RM_SCHEDULER_ADDRESS, + YarnConfiguration.DEFAULT_RM_SCHEDULER_ADDRESS)) + logInfo("Connecting to ResourceManager at " + rmAddress) + return rpc.getProxy(classOf[AMRMProtocol], rmAddress, conf).asInstanceOf[AMRMProtocol] + } + + private def registerApplicationMaster(): RegisterApplicationMasterResponse = { + logInfo("Registering the ApplicationMaster") + val appMasterRequest = Records.newRecord(classOf[RegisterApplicationMasterRequest]) + .asInstanceOf[RegisterApplicationMasterRequest] + appMasterRequest.setApplicationAttemptId(appAttemptId) + // Setting this to master host,port - so that the ApplicationReport at client has some sensible info. + // Users can then monitor stderr/stdout on that node if required. + appMasterRequest.setHost(Utils.localHostName()) + appMasterRequest.setRpcPort(0) + // What do we provide here ? Might make sense to expose something sensible later ? 
+ appMasterRequest.setTrackingUrl("") + return resourceManager.registerApplicationMaster(appMasterRequest) + } + + private def waitForSparkMaster() { + logInfo("Waiting for spark driver to be reachable.") + var driverUp = false + while(!driverUp) { + val driverHost = System.getProperty("spark.driver.host") + val driverPort = System.getProperty("spark.driver.port") + try { + val socket = new Socket(driverHost, driverPort.toInt) + socket.close() + logInfo("Master now available: " + driverHost + ":" + driverPort) + driverUp = true + } catch { + case e: Exception => + logError("Failed to connect to driver at " + driverHost + ":" + driverPort) + Thread.sleep(100) + } + } + } + + private def startUserClass(): Thread = { + logInfo("Starting the user JAR in a separate Thread") + val mainMethod = Class.forName(args.userClass, false, Thread.currentThread.getContextClassLoader) + .getMethod("main", classOf[Array[String]]) + val t = new Thread { + override def run() { + // Copy + var mainArgs: Array[String] = new Array[String](args.userArgs.size()) + args.userArgs.copyToArray(mainArgs, 0, args.userArgs.size()) + mainMethod.invoke(null, mainArgs) + } + } + t.start() + return t + } + + private def allocateWorkers() { + logInfo("Waiting for spark context initialization") + + try { + var sparkContext: SparkContext = null + ApplicationMaster.sparkContextRef.synchronized { + var count = 0 + while (ApplicationMaster.sparkContextRef.get() == null) { + logInfo("Waiting for spark context initialization ... " + count) + count = count + 1 + ApplicationMaster.sparkContextRef.wait(10000L) + } + sparkContext = ApplicationMaster.sparkContextRef.get() + assert(sparkContext != null) + this.yarnAllocator = YarnAllocationHandler.newAllocator(yarnConf, resourceManager, appAttemptId, args, sparkContext.preferredNodeLocationData) + } + + + logInfo("Allocating " + args.numWorkers + " workers.") + // Wait until all containers have finished + // TODO: This is a bit ugly. Can we make it nicer? 
+ // TODO: Handle container failure + while(yarnAllocator.getNumWorkersRunning < args.numWorkers && + // If user thread exists, then quit ! + userThread.isAlive) { + + this.yarnAllocator.allocateContainers(math.max(args.numWorkers - yarnAllocator.getNumWorkersRunning, 0)) + ApplicationMaster.incrementAllocatorLoop(1) + Thread.sleep(100) + } + } finally { + // in case of exceptions, etc - ensure that count is atleast ALLOCATOR_LOOP_WAIT_COUNT : + // so that the loop (in ApplicationMaster.sparkContextInitialized) breaks + ApplicationMaster.incrementAllocatorLoop(ApplicationMaster.ALLOCATOR_LOOP_WAIT_COUNT) + } + logInfo("All workers have launched.") + + // Launch a progress reporter thread, else app will get killed after expiration (def: 10mins) timeout + if (userThread.isAlive){ + // ensure that progress is sent before YarnConfiguration.RM_AM_EXPIRY_INTERVAL_MS elapse. + + val timeoutInterval = yarnConf.getInt(YarnConfiguration.RM_AM_EXPIRY_INTERVAL_MS, 120000) + // must be <= timeoutInterval/ 2. + // On other hand, also ensure that we are reasonably responsive without causing too many requests to RM. + // so atleast 1 minute or timeoutInterval / 10 - whichever is higher. + val interval = math.min(timeoutInterval / 2, math.max(timeoutInterval/ 10, 60000L)) + launchReporterThread(interval) + } + } + + // TODO: We might want to extend this to allocate more containers in case they die ! + private def launchReporterThread(_sleepTime: Long): Thread = { + val sleepTime = if (_sleepTime <= 0 ) 0 else _sleepTime + + val t = new Thread { + override def run() { + while (userThread.isAlive){ + val missingWorkerCount = args.numWorkers - yarnAllocator.getNumWorkersRunning + if (missingWorkerCount > 0) { + logInfo("Allocating " + missingWorkerCount + " containers to make up for (potentially ?) 
lost containers") + yarnAllocator.allocateContainers(missingWorkerCount) + } + else sendProgress() + Thread.sleep(sleepTime) + } + } + } + // setting to daemon status, though this is usually not a good idea. + t.setDaemon(true) + t.start() + logInfo("Started progress reporter thread - sleep time : " + sleepTime) + return t + } + + private def sendProgress() { + logDebug("Sending progress") + // simulated with an allocate request with no nodes requested ... + yarnAllocator.allocateContainers(0) + } + + /* + def printContainers(containers: List[Container]) = { + for (container <- containers) { + logInfo("Launching shell command on a new container." + + ", containerId=" + container.getId() + + ", containerNode=" + container.getNodeId().getHost() + + ":" + container.getNodeId().getPort() + + ", containerNodeURI=" + container.getNodeHttpAddress() + + ", containerState" + container.getState() + + ", containerResourceMemory" + + container.getResource().getMemory()) + } + } + */ + + def finishApplicationMaster() { + val finishReq = Records.newRecord(classOf[FinishApplicationMasterRequest]) + .asInstanceOf[FinishApplicationMasterRequest] + finishReq.setAppAttemptId(appAttemptId) + // TODO: Check if the application has failed or succeeded + finishReq.setFinishApplicationStatus(FinalApplicationStatus.SUCCEEDED) + resourceManager.finishApplicationMaster(finishReq) + } + +} + +object ApplicationMaster { + // number of times to wait for the allocator loop to complete. + // each loop iteration waits for 100ms, so maximum of 3 seconds. + // This is to ensure that we have reasonable number of containers before we start + // TODO: Currently, task to container is computed once (TaskSetManager) - which need not be optimal as more + // containers are available. Might need to handle this better. 
+ private val ALLOCATOR_LOOP_WAIT_COUNT = 30 + def incrementAllocatorLoop(by: Int) { + val count = yarnAllocatorLoop.getAndAdd(by) + if (count >= ALLOCATOR_LOOP_WAIT_COUNT){ + yarnAllocatorLoop.synchronized { + // to wake threads off wait ... + yarnAllocatorLoop.notifyAll() + } + } + } + + private val applicationMasters = new CopyOnWriteArrayList[ApplicationMaster]() + + def register(master: ApplicationMaster) { + applicationMasters.add(master) + } + + val sparkContextRef: AtomicReference[SparkContext] = new AtomicReference[SparkContext](null) + val yarnAllocatorLoop: AtomicInteger = new AtomicInteger(0) + + def sparkContextInitialized(sc: SparkContext): Boolean = { + var modified = false + sparkContextRef.synchronized { + modified = sparkContextRef.compareAndSet(null, sc) + sparkContextRef.notifyAll() + } + + // Add a shutdown hook - as a best case effort in case users do not call sc.stop or do System.exit + // Should not really have to do this, but it helps yarn to evict resources earlier. + // not to mention, prevent Client declaring failure even though we exit'ed properly. + if (modified) { + Runtime.getRuntime().addShutdownHook(new Thread with Logging { + // This is not just to log, but also to ensure that log system is initialized for this instance when we actually are 'run' + logInfo("Adding shutdown hook for context " + sc) + override def run() { + logInfo("Invoking sc stop from shutdown hook") + sc.stop() + // best case ... 
+ for (master <- applicationMasters) master.finishApplicationMaster + } + } ) + } + + // Wait for initialization to complete and atleast 'some' nodes can get allocated + yarnAllocatorLoop.synchronized { + while (yarnAllocatorLoop.get() <= ALLOCATOR_LOOP_WAIT_COUNT){ + yarnAllocatorLoop.wait(1000L) + } + } + modified + } + + def main(argStrings: Array[String]) { + val args = new ApplicationMasterArguments(argStrings) + new ApplicationMaster(args).run() + } +} diff --git a/core/src/hadoop2-yarn/scala/spark/deploy/yarn/ApplicationMasterArguments.scala b/core/src/hadoop2-yarn/scala/spark/deploy/yarn/ApplicationMasterArguments.scala new file mode 100644 index 0000000000..1b00208511 --- /dev/null +++ b/core/src/hadoop2-yarn/scala/spark/deploy/yarn/ApplicationMasterArguments.scala @@ -0,0 +1,77 @@ +package spark.deploy.yarn + +import spark.util.IntParam +import collection.mutable.ArrayBuffer + +class ApplicationMasterArguments(val args: Array[String]) { + var userJar: String = null + var userClass: String = null + var userArgs: Seq[String] = Seq[String]() + var workerMemory = 1024 + var workerCores = 1 + var numWorkers = 2 + + parseArgs(args.toList) + + private def parseArgs(inputArgs: List[String]): Unit = { + val userArgsBuffer = new ArrayBuffer[String]() + + var args = inputArgs + + while (! 
args.isEmpty) { + + args match { + case ("--jar") :: value :: tail => + userJar = value + args = tail + + case ("--class") :: value :: tail => + userClass = value + args = tail + + case ("--args") :: value :: tail => + userArgsBuffer += value + args = tail + + case ("--num-workers") :: IntParam(value) :: tail => + numWorkers = value + args = tail + + case ("--worker-memory") :: IntParam(value) :: tail => + workerMemory = value + args = tail + + case ("--worker-cores") :: IntParam(value) :: tail => + workerCores = value + args = tail + + case Nil => + if (userJar == null || userClass == null) { + printUsageAndExit(1) + } + + case _ => + printUsageAndExit(1, args) + } + } + + userArgs = userArgsBuffer.readOnly + } + + def printUsageAndExit(exitCode: Int, unknownParam: Any = null) { + if (unknownParam != null) { + System.err.println("Unknown/unsupported param " + unknownParam) + } + System.err.println( + "Usage: spark.deploy.yarn.ApplicationMaster [options] \n" + + "Options:\n" + + " --jar JAR_PATH Path to your application's JAR file (required)\n" + + " --class CLASS_NAME Name of your application's main class (required)\n" + + " --args ARGS Arguments to be passed to your application's main class.\n" + + " Mutliple invocations are possible, each will be passed in order.\n" + + " --num-workers NUM Number of workers to start (Default: 2)\n" + + " --worker-cores NUM Number of cores for the workers (Default: 1)\n" + + " --worker-memory MEM Memory per Worker (e.g. 
1000M, 2G) (Default: 1G)\n") + System.exit(exitCode) + } +} diff --git a/core/src/hadoop2-yarn/scala/spark/deploy/yarn/Client.scala b/core/src/hadoop2-yarn/scala/spark/deploy/yarn/Client.scala new file mode 100644 index 0000000000..7a881e26df --- /dev/null +++ b/core/src/hadoop2-yarn/scala/spark/deploy/yarn/Client.scala @@ -0,0 +1,272 @@ +package spark.deploy.yarn + +import java.net.{InetSocketAddress, URI} +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.{FileStatus, FileSystem, Path} +import org.apache.hadoop.net.NetUtils +import org.apache.hadoop.yarn.api._ +import org.apache.hadoop.yarn.api.records._ +import org.apache.hadoop.yarn.api.protocolrecords._ +import org.apache.hadoop.yarn.client.YarnClientImpl +import org.apache.hadoop.yarn.conf.YarnConfiguration +import org.apache.hadoop.yarn.ipc.YarnRPC +import scala.collection.mutable.HashMap +import scala.collection.JavaConversions._ +import spark.{Logging, Utils} +import org.apache.hadoop.yarn.util.{Apps, Records, ConverterUtils} +import org.apache.hadoop.yarn.api.ApplicationConstants.Environment +import spark.deploy.SparkHadoopUtil + +class Client(conf: Configuration, args: ClientArguments) extends YarnClientImpl with Logging { + + def this(args: ClientArguments) = this(new Configuration(), args) + + var rpc: YarnRPC = YarnRPC.create(conf) + val yarnConf: YarnConfiguration = new YarnConfiguration(conf) + + def run() { + init(yarnConf) + start() + logClusterResourceDetails() + + val newApp = super.getNewApplication() + val appId = newApp.getApplicationId() + + verifyClusterResources(newApp) + val appContext = createApplicationSubmissionContext(appId) + val localResources = prepareLocalResources(appId, "spark") + val env = setupLaunchEnv(localResources) + val amContainer = createContainerLaunchContext(newApp, localResources, env) + + appContext.setQueue(args.amQueue) + appContext.setAMContainerSpec(amContainer) + appContext.setUser(args.amUser) + + submitApp(appContext) + + 
monitorApplication(appId) + System.exit(0) + } + + + def logClusterResourceDetails() { + val clusterMetrics: YarnClusterMetrics = super.getYarnClusterMetrics + logInfo("Got Cluster metric info from ASM, numNodeManagers=" + clusterMetrics.getNumNodeManagers) + + val queueInfo: QueueInfo = super.getQueueInfo(args.amQueue) + logInfo("Queue info .. queueName=" + queueInfo.getQueueName + ", queueCurrentCapacity=" + queueInfo.getCurrentCapacity + + ", queueMaxCapacity=" + queueInfo.getMaximumCapacity + ", queueApplicationCount=" + queueInfo.getApplications.size + + ", queueChildQueueCount=" + queueInfo.getChildQueues.size) + } + + + def verifyClusterResources(app: GetNewApplicationResponse) = { + val maxMem = app.getMaximumResourceCapability().getMemory() + logInfo("Max mem capabililty of resources in this cluster " + maxMem) + + // If the cluster does not have enough memory resources, exit. + val requestedMem = (args.amMemory + YarnAllocationHandler.MEMORY_OVERHEAD) + args.numWorkers * args.workerMemory + if (requestedMem > maxMem) { + logError("Cluster cannot satisfy memory resource request of " + requestedMem) + System.exit(1) + } + } + + def createApplicationSubmissionContext(appId: ApplicationId): ApplicationSubmissionContext = { + logInfo("Setting up application submission context for ASM") + val appContext = Records.newRecord(classOf[ApplicationSubmissionContext]) + appContext.setApplicationId(appId) + appContext.setApplicationName("Spark") + return appContext + } + + def prepareLocalResources(appId: ApplicationId, appName: String): HashMap[String, LocalResource] = { + logInfo("Preparing Local resources") + val locaResources = HashMap[String, LocalResource]() + // Upload Spark and the application JAR to the remote file system + // Add them as local resources to the AM + val fs = FileSystem.get(conf) + Map("spark.jar" -> System.getenv("SPARK_JAR"), "app.jar" -> args.userJar, "log4j.properties" -> System.getenv("SPARK_LOG4J_CONF")) + .foreach { case(destName, 
_localPath) => + val localPath: String = if (_localPath != null) _localPath.trim() else "" + if (! localPath.isEmpty()) { + val src = new Path(localPath) + val pathSuffix = appName + "/" + appId.getId() + destName + val dst = new Path(fs.getHomeDirectory(), pathSuffix) + logInfo("Uploading " + src + " to " + dst) + fs.copyFromLocalFile(false, true, src, dst) + val destStatus = fs.getFileStatus(dst) + + val amJarRsrc = Records.newRecord(classOf[LocalResource]).asInstanceOf[LocalResource] + amJarRsrc.setType(LocalResourceType.FILE) + amJarRsrc.setVisibility(LocalResourceVisibility.APPLICATION) + amJarRsrc.setResource(ConverterUtils.getYarnUrlFromPath(dst)) + amJarRsrc.setTimestamp(destStatus.getModificationTime()) + amJarRsrc.setSize(destStatus.getLen()) + locaResources(destName) = amJarRsrc + } + } + return locaResources + } + + def setupLaunchEnv(localResources: HashMap[String, LocalResource]): HashMap[String, String] = { + logInfo("Setting up the launch environment") + val log4jConfLocalRes = localResources.getOrElse("log4j.properties", null) + + val env = new HashMap[String, String]() + Apps.addToEnvironment(env, Environment.USER.name, args.amUser) + + // If log4j present, ensure ours overrides all others + if (log4jConfLocalRes != null) Apps.addToEnvironment(env, Environment.CLASSPATH.name, "./") + + Apps.addToEnvironment(env, Environment.CLASSPATH.name, "./*") + Apps.addToEnvironment(env, Environment.CLASSPATH.name, "$CLASSPATH") + Client.populateHadoopClasspath(yarnConf, env) + SparkHadoopUtil.setYarnMode(env) + env("SPARK_YARN_JAR_PATH") = + localResources("spark.jar").getResource().getScheme.toString() + "://" + + localResources("spark.jar").getResource().getFile().toString() + env("SPARK_YARN_JAR_TIMESTAMP") = localResources("spark.jar").getTimestamp().toString() + env("SPARK_YARN_JAR_SIZE") = localResources("spark.jar").getSize().toString() + + env("SPARK_YARN_USERJAR_PATH") = + localResources("app.jar").getResource().getScheme.toString() + "://" + + 
/**
 * Renders the user's application arguments as a fragment of the ApplicationMaster
 * command line: one ` --args  '<arg>' ` group per argument, in the original order.
 *
 * The resulting command is interpreted by a shell (it already relies on `1>` / `2>`
 * redirections), so each argument is single-quoted. A single quote inside an argument
 * would terminate the quoting early, so it is rewritten as `'\''` (close quote, escaped
 * quote, reopen quote). Arguments without a single quote are rendered exactly as before.
 *
 * @param clientArgs parsed client arguments; only `userArgs` is read
 * @return the concatenated fragment, or the empty string when there are no user args
 */
def userArgsToString(clientArgs: ClientArguments): String = {
  val prefix = " --args "
  clientArgs.userArgs.map { arg =>
    prefix + " '" + arg.replace("'", "'\\''") + "' "
  }.mkString
}
+ // The context is, default gc for server class machines end up using all cores to do gc - hence if there are multiple containers in same + // node, spark gc effects all other containers performance (which can also be other spark containers) + // Instead of using this, rely on cpusets by YARN to enforce spark behaves 'properly' in multi-tenant environments. Not sure how default java gc behaves if it is + // limited to subset of cores on a node. + if (env.isDefinedAt("SPARK_USE_CONC_INCR_GC") && java.lang.Boolean.parseBoolean(env("SPARK_USE_CONC_INCR_GC"))) { + // In our expts, using (default) throughput collector has severe perf ramnifications in multi-tenant machines + JAVA_OPTS += " -XX:+UseConcMarkSweepGC " + JAVA_OPTS += " -XX:+CMSIncrementalMode " + JAVA_OPTS += " -XX:+CMSIncrementalPacing " + JAVA_OPTS += " -XX:CMSIncrementalDutyCycleMin=0 " + JAVA_OPTS += " -XX:CMSIncrementalDutyCycle=10 " + } + if (env.isDefinedAt("SPARK_JAVA_OPTS")) { + JAVA_OPTS += env("SPARK_JAVA_OPTS") + " " + } + + // Command for the ApplicationMaster + val commands = List[String]("java " + + " -server " + + JAVA_OPTS + + " spark.deploy.yarn.ApplicationMaster" + + " --class " + args.userClass + + " --jar " + args.userJar + + userArgsToString(args) + + " --worker-memory " + args.workerMemory + + " --worker-cores " + args.workerCores + + " --num-workers " + args.numWorkers + + " 1> " + ApplicationConstants.LOG_DIR_EXPANSION_VAR + "/stdout" + + " 2> " + ApplicationConstants.LOG_DIR_EXPANSION_VAR + "/stderr") + logInfo("Command for the ApplicationMaster: " + commands(0)) + amContainer.setCommands(commands) + + val capability = Records.newRecord(classOf[Resource]).asInstanceOf[Resource] + // Memory for the ApplicationMaster + capability.setMemory(args.amMemory + YarnAllocationHandler.MEMORY_OVERHEAD) + amContainer.setResource(capability) + + return amContainer + } + + def submitApp(appContext: ApplicationSubmissionContext) = { + // Submit the application to the applications manager + 
/** Entry point and shared helpers for the YARN client. */
object Client {
  /** Parses the command line, switches Spark to YARN mode and runs the client. */
  def main(argStrings: Array[String]) {
    val args = new ClientArguments(argStrings)
    SparkHadoopUtil.setYarnMode()
    new Client(args).run
  }

  /**
   * Appends every entry of the configured YARN application classpath to the CLASSPATH
   * variable in `env`.
   *
   * Based on code from org.apache.hadoop.mapreduce.v2.util.MRApps
   *
   * @param conf configuration to read `YarnConfiguration.YARN_APPLICATION_CLASSPATH` from
   * @param env  launch environment that is updated in place
   */
  def populateHadoopClasspath(conf: Configuration, env: HashMap[String, String]) {
    // Configuration.getStrings returns null (not an empty array) when the property is
    // not set; treat that as "nothing to add" instead of failing with an NPE.
    val classpathEntries =
      Option(conf.getStrings(YarnConfiguration.YARN_APPLICATION_CLASSPATH))
        .getOrElse(Array.empty[String])
    for (c <- classpathEntries) {
      Apps.addToEnvironment(env, Environment.CLASSPATH.name, c.trim)
    }
  }
}
// TODO: Add code and support for ensuring that yarn resource 'asks' are location aware !
/**
 * Command-line options of `spark.deploy.yarn.Client`.
 *
 * The arguments are parsed eagerly in the constructor. On an unknown option, or when
 * `--jar` / `--class` is missing, the usage text is printed to stderr and the JVM exits.
 *
 * @param args raw command-line arguments
 */
class ClientArguments(val args: Array[String]) {
  var userJar: String = null
  var userClass: String = null
  var userArgs: Seq[String] = Seq[String]()
  var workerMemory = 1024
  var workerCores = 1
  var numWorkers = 2
  var amUser = System.getProperty("user.name")
  var amQueue = System.getProperty("QUEUE", "default")
  var amMemory: Int = 512
  // TODO
  var inputFormatInfo: List[InputFormatInfo] = null

  parseArgs(args.toList)

  /**
   * Consumes `inputArgs` as option/value pairs and stores the values in the fields above.
   * `--args` may be repeated; its values are collected in order.
   */
  private def parseArgs(inputArgs: List[String]): Unit = {
    val userArgsBuffer: ArrayBuffer[String] = new ArrayBuffer[String]()
    // Nothing populates this map yet, so inputFormatInfo ends up as an empty list.
    val inputFormatMap: HashMap[String, InputFormatInfo] = new HashMap[String, InputFormatInfo]()

    var args = inputArgs

    while (! args.isEmpty) {

      args match {
        case ("--jar") :: value :: tail =>
          userJar = value
          args = tail

        case ("--class") :: value :: tail =>
          userClass = value
          args = tail

        case ("--args") :: value :: tail =>
          userArgsBuffer += value
          args = tail

        case ("--master-memory") :: MemoryParam(value) :: tail =>
          amMemory = value
          args = tail

        case ("--num-workers") :: IntParam(value) :: tail =>
          numWorkers = value
          args = tail

        case ("--worker-memory") :: MemoryParam(value) :: tail =>
          workerMemory = value
          args = tail

        case ("--worker-cores") :: IntParam(value) :: tail =>
          workerCores = value
          args = tail

        case ("--user") :: value :: tail =>
          amUser = value
          args = tail

        case ("--queue") :: value :: tail =>
          amQueue = value
          args = tail

        case _ =>
          printUsageAndExit(1, args)
      }
    }

    // The required-option check used to be a `case Nil` inside the loop above, where it
    // could never match (the loop only runs while the list is non-empty). It has to run
    // once all arguments have been consumed.
    if (userJar == null || userClass == null) {
      printUsageAndExit(1)
    }

    userArgs = userArgsBuffer.readOnly
    inputFormatInfo = inputFormatMap.values.toList
  }


  /**
   * Prints the usage text to stderr and terminates the JVM.
   *
   * @param exitCode     process exit status
   * @param unknownParam when non-null, reported as the offending parameter before the usage
   */
  def printUsageAndExit(exitCode: Int, unknownParam: Any = null) {
    if (unknownParam != null) {
      System.err.println("Unknown/unsupported param " + unknownParam)
    }
    System.err.println(
      "Usage: spark.deploy.yarn.Client [options] \n" +
      "Options:\n" +
      " --jar JAR_PATH Path to your application's JAR file (required)\n" +
      " --class CLASS_NAME Name of your application's main class (required)\n" +
      " --args ARGS Arguments to be passed to your application's main class.\n" +
      " Mutliple invocations are possible, each will be passed in order.\n" +
      " --num-workers NUM Number of workers to start (Default: 2)\n" +
      " --worker-cores NUM Number of cores for the workers (Default: 1). This is unsused right now.\n" +
      " --master-memory MEM Memory for Master (e.g. 1000M, 2G) (Default: 512 Mb)\n" +
      " --worker-memory MEM Memory per Worker (e.g. 1000M, 2G) (Default: 1G)\n" +
      " --queue QUEUE The hadoop queue to use for allocation requests (Default: 'default')\n" +
      " --user USERNAME Run the ApplicationMaster (and slaves) as a different user\n"
      )
    System.exit(exitCode)
  }

}
/**
 * Starts one Spark worker JVM (`spark.executor.StandaloneExecutorBackend`) inside an
 * already allocated YARN container.
 *
 * `run` connects to the ContainerManager of the container's node and submits a launch
 * context made of the local resources, environment and command line built below.
 */
class WorkerRunnable(container: Container, conf: Configuration, masterAddress: String,
    slaveId: String, hostname: String, workerMemory: Int, workerCores: Int)
  extends Runnable with Logging {

  var rpc: YarnRPC = YarnRPC.create(conf)
  var cm: ContainerManager = null
  val yarnConf: YarnConfiguration = new YarnConfiguration(conf)

  /** Connects to the node's ContainerManager, then launches the worker. */
  def run = {
    logInfo("Starting Worker Container")
    cm = connectToCM
    startContainer
  }

  /** Builds the ContainerLaunchContext for the worker and sends the start request. */
  def startContainer = {
    logInfo("Setting up ContainerLaunchContext")

    val ctx = Records.newRecord(classOf[ContainerLaunchContext])
    ctx.setContainerId(container.getId())
    ctx.setResource(container.getResource())
    ctx.setLocalResources(prepareLocalResources)

    val env = prepareEnvironment
    ctx.setEnvironment(env)

    // JVM options: fixed heap of workerMemory MB, followed by SPARK_JAVA_OPTS when defined.
    val workerMemoryString = workerMemory + "m"
    val heapOpts = "-Xms" + workerMemoryString + " -Xmx" + workerMemoryString + " "
    val userOpts = env.get("SPARK_JAVA_OPTS").map(_ + " ").getOrElse("")
    val javaOpts = heapOpts + userOpts

    // Intentionally NOT enabled - kept so that people can refer to the properties if
    // required. Remove once the cpuset version is pushed out.
    // The context is: the default gc for server class machines ends up using all cores to
    // do gc, so with multiple containers on the same node spark gc affects the performance
    // of all other containers (which can also be other spark containers). Instead of this,
    // rely on cpusets by YARN to make spark behave 'properly' in multi-tenant
    // environments. Not sure how the default java gc behaves if it is limited to a subset
    // of the cores on a node.
    // The options below were only meant as a default when SPARK_JAVA_OPTS is absent (other
    // modes/config may be set in SPARK_JAVA_OPTS, and we do not want to mess with that).
    // They are based on
    // http://www.oracle.com/technetwork/java/gc-tuning-5-138395.html#0.0.0.%20When%20to%20Use%20the%20Concurrent%20Low%20Pause%20Collector|outline
    //   -XX:+UseConcMarkSweepGC
    //   -XX:+CMSIncrementalMode
    //   -XX:+CMSIncrementalPacing
    //   -XX:CMSIncrementalDutyCycleMin=0
    //   -XX:CMSIncrementalDutyCycle=10

    ctx.setUser(UserGroupInformation.getCurrentUser().getShortUserName())

    // Kill the JVM if OOM is raised - leverage yarn's failure handling to cause
    // rescheduling. Not killing the task leaves various aspects of the worker and (to some
    // extent) the jvm in an inconsistent state.
    // TODO: If the OOM is not recoverable by rescheduling it on different node, then do
    // 'something' to fail job ... akin to blacklisting trackers in mapred ?
    val launchCommand = "java " +
      " -server " +
      " -XX:OnOutOfMemoryError='kill %p' " +
      javaOpts +
      " spark.executor.StandaloneExecutorBackend " +
      masterAddress + " " +
      slaveId + " " +
      hostname + " " +
      workerCores +
      " 1> " + ApplicationConstants.LOG_DIR_EXPANSION_VAR + "/stdout" +
      " 2> " + ApplicationConstants.LOG_DIR_EXPANSION_VAR + "/stderr"
    val commands = List[String](launchCommand)
    logInfo("Setting up worker with commands: " + commands)
    ctx.setCommands(commands)

    // Send the start request to the ContainerManager
    val startReq = Records.newRecord(classOf[StartContainerRequest])
    startReq.setContainerLaunchContext(ctx)
    cm.startContainer(startReq)
  }

  /**
   * Describes one application-visible file resource whose location, timestamp and size
   * are read from the three named environment variables.
   */
  private def localResourceFromEnv(pathVar: String, timestampVar: String,
      sizeVar: String): LocalResource = {
    val resource = Records.newRecord(classOf[LocalResource])
    resource.setType(LocalResourceType.FILE)
    resource.setVisibility(LocalResourceVisibility.APPLICATION)
    resource.setResource(ConverterUtils.getYarnUrlFromURI(new URI(System.getenv(pathVar))))
    resource.setTimestamp(System.getenv(timestampVar).toLong)
    resource.setSize(System.getenv(sizeVar).toLong)
    resource
  }

  /**
   * Local resources of the worker container: the Spark jar, the user jar and - only when
   * SPARK_YARN_LOG4J_PATH is set - the log4j configuration.
   */
  def prepareLocalResources: HashMap[String, LocalResource] = {
    logInfo("Preparing Local resources")
    val resources = HashMap[String, LocalResource]()

    // Spark JAR
    resources("spark.jar") = localResourceFromEnv(
      "SPARK_YARN_JAR_PATH", "SPARK_YARN_JAR_TIMESTAMP", "SPARK_YARN_JAR_SIZE")

    // User JAR
    resources("app.jar") = localResourceFromEnv(
      "SPARK_YARN_USERJAR_PATH", "SPARK_YARN_USERJAR_TIMESTAMP", "SPARK_YARN_USERJAR_SIZE")

    // Log4j conf - if available
    if (System.getenv("SPARK_YARN_LOG4J_PATH") != null) {
      resources("log4j.properties") = localResourceFromEnv(
        "SPARK_YARN_LOG4J_PATH", "SPARK_YARN_LOG4J_TIMESTAMP", "SPARK_YARN_LOG4J_SIZE")
    }

    logInfo("Prepared Local resources " + resources)
    resources
  }

  /**
   * Environment of the worker container: USER, the CLASSPATH entries and a copy of every
   * variable of this process whose name starts with "SPARK".
   */
  def prepareEnvironment: HashMap[String, String] = {
    val env = new HashMap[String, String]()
    // should we add this ?
    Apps.addToEnvironment(env, Environment.USER.name, Utils.getUserNameFromEnvironment())

    // If log4j present, ensure ours overrides all others.
    // Which of the two entries is correct ?
    val log4jEntries =
      if (System.getenv("SPARK_YARN_LOG4J_PATH") != null) List("./log4j.properties", "./")
      else Nil

    for (entry <- log4jEntries ::: List("./*", "$CLASSPATH")) {
      Apps.addToEnvironment(env, Environment.CLASSPATH.name, entry)
    }
    Client.populateHadoopClasspath(yarnConf, env)

    for ((key, value) <- System.getenv() if key.startsWith("SPARK")) {
      env(key) = value
    }
    env
  }

  /** Opens an RPC proxy to the ContainerManager running on the container's node. */
  def connectToCM: ContainerManager = {
    val nodeId = container.getNodeId()
    val cmHostPortStr = nodeId.getHost() + ":" + nodeId.getPort()
    val cmAddress = NetUtils.createSocketAddr(cmHostPortStr)
    logInfo("Connecting to ContainerManager at " + cmHostPortStr)
    rpc.getProxy(classOf[ContainerManager], cmAddress, conf).asInstanceOf[ContainerManager]
  }

}
/**
 * Locality level of a container request: a specific host, a rack, or anywhere.
 *
 * The value names are taken from the `val` identifiers below, so they do not need to be
 * passed to the `Enumeration` constructor (that constructor is deprecated since
 * Scala 2.10).
 */
object AllocationType extends Enumeration {
  type AllocationType = Value
  val HOST, RACK, ANY = Value
}
/**
 * Performs one allocate round-trip with the ResourceManager and processes the response.
 *
 * Newly allocated containers are grouped into host-local, rack-local and off-rack sets
 * (according to `preferredHostToCount` / `preferredRackToCount` and what is already
 * running), surplus or undersized containers are queued for release, and a
 * `WorkerRunnable` thread is started for each container that is kept, up to `maxWorkers`.
 * Completed containers reported in the same response are removed from the bookkeeping
 * maps.
 *
 * @param workersToRequest number of additional worker containers to ask for
 */
def allocateContainers(workersToRequest: Int) {
  // We need to send the request only once from what I understand ... but for now, not modifying this much.

  // Keep polling the Resource Manager for containers
  val amResp = allocateWorkerResources(workersToRequest).getAMResponse

  val _allocatedContainers = amResp.getAllocatedContainers()
  if (_allocatedContainers.size > 0) {

    logDebug("Allocated " + _allocatedContainers.size + " containers, current count " +
      numWorkersRunning.get() + ", to-be-released " + releasedContainerList +
      ", pendingReleaseContainers : " + pendingReleaseContainers)
    logDebug("Cluster Resources: " + amResp.getAvailableResources)

    val hostToContainers = new HashMap[String, ArrayBuffer[Container]]()

    // Group usable containers by host; containers that are too small are released.
    for (container <- _allocatedContainers) {
      if (isResourceConstraintSatisfied(container)) {
        val host = container.getNodeId.getHost
        val containers = hostToContainers.getOrElseUpdate(host, new ArrayBuffer[Container]())

        containers += container
      }
      // Add all ignored containers to released list
      else releasedContainerList.add(container.getId())
    }

    // Find the appropriate containers to use
    // Slightly non trivial groupBy I guess ...
    val dataLocalContainers = new HashMap[String, ArrayBuffer[Container]]()
    val rackLocalContainers = new HashMap[String, ArrayBuffer[Container]]()
    val offRackContainers = new HashMap[String, ArrayBuffer[Container]]()

    for ((candidateHost, hostContainers) <- hostToContainers) {
      val maxExpectedHostCount = preferredHostToCount.getOrElse(candidateHost, 0)
      val requiredHostCount = maxExpectedHostCount - allocatedContainersOnHost(candidateHost)

      // null means "every container of this host has been assigned to a group"
      var remainingContainers: ArrayBuffer[Container] = hostContainers

      if (requiredHostCount >= remainingContainers.size){
        // Since we got <= required containers, add all to dataLocalContainers
        dataLocalContainers.put(candidateHost, remainingContainers)
        // all consumed
        remainingContainers = null
      }
      else if (requiredHostCount > 0) {
        // container list has more containers than we need for data locality.
        // Keep exactly requiredHostCount of them as data local. (The split index used to
        // be size - requiredHostCount, which kept the surplus and released the needed
        // ones.)
        val (dataLocal, remaining) = remainingContainers.splitAt(requiredHostCount)
        dataLocalContainers.put(candidateHost, dataLocal)

        // yarn has nasty habit of allocating a tonne of containers on a host - discourage this :
        // add remaining to release list. If we have insufficient containers, next allocation cycle
        // will reallocate (but wont treat it as data local)
        for (container <- remaining) releasedContainerList.add(container.getId())
        remainingContainers = null
      }

      // now rack local
      if (remainingContainers != null){
        val rack = YarnAllocationHandler.lookupRack(conf, candidateHost)

        if (rack != null){
          val maxExpectedRackCount = preferredRackToCount.getOrElse(rack, 0)
          val requiredRackCount = maxExpectedRackCount - allocatedContainersOnRack(rack) -
            rackLocalContainers.get(rack).getOrElse(List()).size

          if (requiredRackCount >= remainingContainers.size){
            // All of them are wanted on this rack. Append to the rack's list: putting them
            // into dataLocalContainers under the rack name let a second host of the same
            // rack overwrite (and thereby lose) the first host's containers, and hid them
            // from the requiredRackCount computation above.
            val existingRackLocal = rackLocalContainers.getOrElseUpdate(rack, new ArrayBuffer[Container]())
            existingRackLocal ++= remainingContainers
            // all consumed
            remainingContainers = null
          }
          else if (requiredRackCount > 0) {
            // More containers than the rack still needs: keep requiredRackCount as rack
            // local, the rest falls through to the off rack list below.
            val (rackLocal, remaining) = remainingContainers.splitAt(requiredRackCount)
            val existingRackLocal = rackLocalContainers.getOrElseUpdate(rack, new ArrayBuffer[Container]())

            existingRackLocal ++= rackLocal
            remainingContainers = remaining
          }
        }
      }

      // If still not consumed, then it is off rack host - add to that list.
      if (remainingContainers != null){
        offRackContainers.put(candidateHost, remainingContainers)
      }
    }

    // Now that we have split the containers into various groups, go through them in order :
    // first host local, then rack local and then off rack (everything else).
    // Note that the list we create below tries to ensure that not all containers end up within a host
    // if there are sufficiently large number of hosts/containers.

    val allocatedContainers = new ArrayBuffer[Container](_allocatedContainers.size)
    allocatedContainers ++= ClusterScheduler.prioritizeContainers(dataLocalContainers)
    allocatedContainers ++= ClusterScheduler.prioritizeContainers(rackLocalContainers)
    allocatedContainers ++= ClusterScheduler.prioritizeContainers(offRackContainers)

    // Run each of the allocated containers
    for (container <- allocatedContainers) {
      val numWorkersRunningNow = numWorkersRunning.incrementAndGet()
      val workerHostname = container.getNodeId.getHost
      val containerId = container.getId

      assert (container.getResource.getMemory >= (workerMemory + YarnAllocationHandler.MEMORY_OVERHEAD))

      if (numWorkersRunningNow > maxWorkers) {
        logInfo("Ignoring container " + containerId + " at host " + workerHostname +
          " .. we already have required number of containers")
        releasedContainerList.add(containerId)
        // reset counter back to old value.
        numWorkersRunning.decrementAndGet()
      }
      else {
        // deallocate + allocate can result in reusing id's wrongly - so use a different counter (workerIdCounter)
        val workerId = workerIdCounter.incrementAndGet().toString
        val driverUrl = "akka://spark@%s:%s/user/%s".format(
          System.getProperty("spark.driver.host"), System.getProperty("spark.driver.port"),
          StandaloneSchedulerBackend.ACTOR_NAME)

        logInfo("launching container on " + containerId + " host " + workerHostname)
        // just to be safe, simply remove it from pendingReleaseContainers. Should not be there, but ..
        pendingReleaseContainers.remove(containerId)

        val rack = YarnAllocationHandler.lookupRack(conf, workerHostname)
        allocatedHostToContainersMap.synchronized {
          val containerSet = allocatedHostToContainersMap.getOrElseUpdate(workerHostname, new HashSet[ContainerId]())

          containerSet += containerId
          allocatedContainerToHostMap.put(containerId, workerHostname)
          if (rack != null) allocatedRackCount.put(rack, allocatedRackCount.getOrElse(rack, 0) + 1)
        }

        new Thread(
          new WorkerRunnable(container, conf, driverUrl, workerId,
            workerHostname, workerMemory, workerCores)
        ).start()
      }
    }
    logDebug("After allocated " + allocatedContainers.size + " containers (orig : " +
      _allocatedContainers.size + "), current count " + numWorkersRunning.get() +
      ", to-be-released " + releasedContainerList + ", pendingReleaseContainers : " + pendingReleaseContainers)
  }


  val completedContainers = amResp.getCompletedContainersStatuses()
  if (completedContainers.size > 0){
    logDebug("Completed " + completedContainers.size + " containers, current count " + numWorkersRunning.get() +
      ", to-be-released " + releasedContainerList + ", pendingReleaseContainers : " + pendingReleaseContainers)

    for (completedContainer <- completedContainers){
      val containerId = completedContainer.getContainerId

      // Was this released by us ? If yes, then simply remove from containerSet and move on.
      if (pendingReleaseContainers.containsKey(containerId)) {
        pendingReleaseContainers.remove(containerId)
      }
      else {
        // simply decrement count - next iteration of ReporterThread will take care of allocating !
        numWorkersRunning.decrementAndGet()
        logInfo("Container completed ? nodeId: " + containerId + ", state " + completedContainer.getState +
          " httpaddress: " + completedContainer.getDiagnostics)
      }

      // The three allocation maps are always updated together under this lock.
      allocatedHostToContainersMap.synchronized {
        if (allocatedContainerToHostMap.containsKey(containerId)) {
          val host = allocatedContainerToHostMap.get(containerId).getOrElse(null)
          assert (host != null)

          val containerSet = allocatedHostToContainersMap.get(host).getOrElse(null)
          assert (containerSet != null)

          containerSet -= containerId
          if (containerSet.isEmpty) allocatedHostToContainersMap.remove(host)
          else allocatedHostToContainersMap.update(host, containerSet)

          allocatedContainerToHostMap -= containerId

          // doing this within locked context, sigh ... move to outside ?
          val rack = YarnAllocationHandler.lookupRack(conf, host)
          if (rack != null) {
            val rackCount = allocatedRackCount.getOrElse(rack, 0) - 1
            if (rackCount > 0) allocatedRackCount.put(rack, rackCount)
            else allocatedRackCount.remove(rack)
          }
        }
      }
    }
    logDebug("After completed " + completedContainers.size + " containers, current count " +
      numWorkersRunning.get() + ", to-be-released " + releasedContainerList +
      ", pendingReleaseContainers : " + pendingReleaseContainers)
  }
}
/**
 * Builds and sends one AllocateRequest to the ResourceManager.
 *
 * Without host preferences (or when no workers are requested) a single ANY request is
 * sent. Otherwise the request list is: one HOST request per preferred host that still
 * lacks containers, the RACK requests derived from those, and finally an ANY request for
 * `numWorkers`. Containers queued for release are attached to the same request.
 *
 * @param numWorkers number of worker containers to ask for
 * @return the ResourceManager's response
 */
private def allocateWorkerResources(numWorkers: Int): AllocateResponse = {

  val resourceRequests: List[ResourceRequest] =
    if (numWorkers <= 0 || preferredHostToCount.isEmpty) {
      // default.
      logDebug("numWorkers: " + numWorkers + ", host preferences ? " + preferredHostToCount.isEmpty)
      List(createResourceRequest(AllocationType.ANY, null, numWorkers, YarnAllocationHandler.PRIORITY))
    }
    else {
      // request for all hosts in preferred nodes and for numWorkers -
      // candidates.size, request by default allocation policy.
      val hostAsks: List[ResourceRequest] = preferredHostToCount.toList.flatMap {
        case (candidateHost, candidateCount) =>
          val stillNeeded = candidateCount - allocatedContainersOnHost(candidateHost)
          if (stillNeeded > 0) {
            List(createResourceRequest(
              AllocationType.HOST, candidateHost, stillNeeded, YarnAllocationHandler.PRIORITY))
          }
          else Nil
      }
      val rackAsks: List[ResourceRequest] = createRackResourceRequests(hostAsks)
      val anyAsk: ResourceRequest =
        createResourceRequest(AllocationType.ANY, null, numWorkers, YarnAllocationHandler.PRIORITY)

      hostAsks ::: rackAsks ::: List(anyAsk)
    }

  val allocateRequest = Records.newRecord(classOf[AllocateRequest])
  allocateRequest.setResponseId(lastResponseId.incrementAndGet)
  allocateRequest.setApplicationAttemptId(appAttemptId)
  allocateRequest.addAllAsks(resourceRequests)

  // Drains the to-be-released list; those containers become "pending release".
  val releases = createReleasedContainerList()
  allocateRequest.addAllReleases(releases)

  if (numWorkers > 0) {
    logInfo("Allocating " + numWorkers + " worker containers with " + (workerMemory + YarnAllocationHandler.MEMORY_OVERHEAD) + " of memory each.")
  }
  else {
    logDebug("Empty allocation req .. release : " + releases)
  }

  resourceRequests.foreach { ask =>
    logInfo("rsrcRequest ... host : " + ask.getHostName + ", numContainers : " + ask.getNumContainers +
      ", p = " + ask.getPriority().getPriority + ", capability: " + ask.getCapability)
  }
  resourceManager.allocate(allocateRequest)
}
+ memCapability.setMemory(workerMemory + YarnAllocationHandler.MEMORY_OVERHEAD) + rsrcRequest.setCapability(memCapability) + + val pri = Records.newRecord(classOf[Priority]) + pri.setPriority(priority) + rsrcRequest.setPriority(pri) + + rsrcRequest.setHostName(hostname) + + rsrcRequest.setNumContainers(java.lang.Math.max(numWorkers, 0)) + rsrcRequest + } + + def createReleasedContainerList(): ArrayBuffer[ContainerId] = { + + val retval = new ArrayBuffer[ContainerId](1) + // iterator on COW list ... + for (container <- releasedContainerList.iterator()){ + retval += container + } + // remove from the original list. + if (! retval.isEmpty) { + releasedContainerList.removeAll(retval) + for (v <- retval) pendingReleaseContainers.put(v, true) + logInfo("Releasing " + retval.size + " containers. pendingReleaseContainers : " + + pendingReleaseContainers) + } + + retval + } +} + +object YarnAllocationHandler { + + val ANY_HOST = "*" + // all requests are issued with same priority : we do not (yet) have any distinction between + // request types (like map/reduce in hadoop for example) + val PRIORITY = 1 + + // Additional memory overhead - in mb + val MEMORY_OVERHEAD = 384 + + // host to rack map - saved from allocation requests + // We are expecting this not to change. + // Note that it is possible for this to change : and RM will indicate that to us via update + // response to allocate. But we are punting on handling that for now. 
+ private val hostToRack = new ConcurrentHashMap[String, String]() + private val rackToHostSet = new ConcurrentHashMap[String, JSet[String]]() + + def newAllocator(conf: Configuration, + resourceManager: AMRMProtocol, appAttemptId: ApplicationAttemptId, + args: ApplicationMasterArguments, + map: collection.Map[String, collection.Set[SplitInfo]]): YarnAllocationHandler = { + + val (hostToCount, rackToCount) = generateNodeToWeight(conf, map) + + + new YarnAllocationHandler(conf, resourceManager, appAttemptId, args.numWorkers, + args.workerMemory, args.workerCores, hostToCount, rackToCount) + } + + def newAllocator(conf: Configuration, + resourceManager: AMRMProtocol, appAttemptId: ApplicationAttemptId, + maxWorkers: Int, workerMemory: Int, workerCores: Int, + map: collection.Map[String, collection.Set[SplitInfo]]): YarnAllocationHandler = { + + val (hostToCount, rackToCount) = generateNodeToWeight(conf, map) + + new YarnAllocationHandler(conf, resourceManager, appAttemptId, maxWorkers, + workerMemory, workerCores, hostToCount, rackToCount) + } + + // A simple method to copy the split info map. + private def generateNodeToWeight(conf: Configuration, input: collection.Map[String, collection.Set[SplitInfo]]) : + // host to count, rack to count + (Map[String, Int], Map[String, Int]) = { + + if (input == null) return (Map[String, Int](), Map[String, Int]()) + + val hostToCount = new HashMap[String, Int] + val rackToCount = new HashMap[String, Int] + + for ((host, splits) <- input) { + val hostCount = hostToCount.getOrElse(host, 0) + hostToCount.put(host, hostCount + splits.size) + + val rack = lookupRack(conf, host) + if (rack != null){ + val rackCount = rackToCount.getOrElse(host, 0) + rackToCount.put(host, rackCount + splits.size) + } + } + + (hostToCount.toMap, rackToCount.toMap) + } + + def lookupRack(conf: Configuration, host: String): String = { + if (! 
hostToRack.contains(host)) populateRackInfo(conf, host) + hostToRack.get(host) + } + + def fetchCachedHostsForRack(rack: String): Option[Set[String]] = { + val set = rackToHostSet.get(rack) + if (set == null) return None + + // No better way to get a Set[String] from JSet ? + val convertedSet: collection.mutable.Set[String] = set + Some(convertedSet.toSet) + } + + def populateRackInfo(conf: Configuration, hostname: String) { + Utils.checkHost(hostname) + + if (!hostToRack.containsKey(hostname)) { + // If there are repeated failures to resolve, all to an ignore list ? + val rackInfo = RackResolver.resolve(conf, hostname) + if (rackInfo != null && rackInfo.getNetworkLocation != null) { + val rack = rackInfo.getNetworkLocation + hostToRack.put(hostname, rack) + if (! rackToHostSet.containsKey(rack)) { + rackToHostSet.putIfAbsent(rack, Collections.newSetFromMap(new ConcurrentHashMap[String, JBoolean]())) + } + rackToHostSet.get(rack).add(hostname) + + // Since RackResolver caches, we are disabling this for now ... + } /* else { + // right ? Else we will keep calling rack resolver in case we cant resolve rack info ... 
+ hostToRack.put(hostname, null) + } */ + } + } +} diff --git a/core/src/hadoop2-yarn/scala/spark/scheduler/cluster/YarnClusterScheduler.scala b/core/src/hadoop2-yarn/scala/spark/scheduler/cluster/YarnClusterScheduler.scala new file mode 100644 index 0000000000..ed732d36bf --- /dev/null +++ b/core/src/hadoop2-yarn/scala/spark/scheduler/cluster/YarnClusterScheduler.scala @@ -0,0 +1,42 @@ +package spark.scheduler.cluster + +import spark._ +import spark.deploy.yarn.{ApplicationMaster, YarnAllocationHandler} +import org.apache.hadoop.conf.Configuration + +/** + * + * This is a simple extension to ClusterScheduler - to ensure that appropriate initialization of ApplicationMaster, etc is done + */ +private[spark] class YarnClusterScheduler(sc: SparkContext, conf: Configuration) extends ClusterScheduler(sc) { + + def this(sc: SparkContext) = this(sc, new Configuration()) + + // Nothing else for now ... initialize application master : which needs sparkContext to determine how to allocate + // Note that only the first creation of SparkContext influences (and ideally, there must be only one SparkContext, right ?) + // Subsequent creations are ignored - since nodes are already allocated by then. 
+ + + // By default, rack is unknown + override def getRackForHost(hostPort: String): Option[String] = { + val host = Utils.parseHostPort(hostPort)._1 + val retval = YarnAllocationHandler.lookupRack(conf, host) + if (retval != null) Some(retval) else None + } + + // By default, if rack is unknown, return nothing + override def getCachedHostsForRack(rack: String): Option[Set[String]] = { + if (rack == None || rack == null) return None + + YarnAllocationHandler.fetchCachedHostsForRack(rack) + } + + override def postStartHook() { + val sparkContextInitialized = ApplicationMaster.sparkContextInitialized(sc) + if (sparkContextInitialized){ + // Wait for a few seconds for the slaves to bootstrap and register with master - best case attempt + Thread.sleep(3000L) + } + logInfo("YarnClusterScheduler.postStartHook done") + } +} diff --git a/core/src/hadoop2/scala/org/apache/hadoop/mapred/HadoopMapRedUtil.scala b/core/src/hadoop2/scala/org/apache/hadoop/mapred/HadoopMapRedUtil.scala index 35300cea58..a0652d7fc7 100644 --- a/core/src/hadoop2/scala/org/apache/hadoop/mapred/HadoopMapRedUtil.scala +++ b/core/src/hadoop2/scala/org/apache/hadoop/mapred/HadoopMapRedUtil.scala @@ -4,4 +4,7 @@ trait HadoopMapRedUtil { def newJobContext(conf: JobConf, jobId: JobID): JobContext = new JobContextImpl(conf, jobId) def newTaskAttemptContext(conf: JobConf, attemptId: TaskAttemptID): TaskAttemptContext = new TaskAttemptContextImpl(conf, attemptId) + + def newTaskAttemptID(jtIdentifier: String, jobId: Int, isMap: Boolean, taskId: Int, attemptId: Int) = new TaskAttemptID(jtIdentifier, + jobId, isMap, taskId, attemptId) } diff --git a/core/src/hadoop2/scala/org/apache/hadoop/mapreduce/HadoopMapReduceUtil.scala b/core/src/hadoop2/scala/org/apache/hadoop/mapreduce/HadoopMapReduceUtil.scala index 7afdbff320..7fdbe322fd 100644 --- a/core/src/hadoop2/scala/org/apache/hadoop/mapreduce/HadoopMapReduceUtil.scala +++ b/core/src/hadoop2/scala/org/apache/hadoop/mapreduce/HadoopMapReduceUtil.scala @@ -7,4 
+7,7 @@ trait HadoopMapReduceUtil { def newJobContext(conf: Configuration, jobId: JobID): JobContext = new JobContextImpl(conf, jobId) def newTaskAttemptContext(conf: Configuration, attemptId: TaskAttemptID): TaskAttemptContext = new TaskAttemptContextImpl(conf, attemptId) + + def newTaskAttemptID(jtIdentifier: String, jobId: Int, isMap: Boolean, taskId: Int, attemptId: Int) = new TaskAttemptID(jtIdentifier, + jobId, isMap, taskId, attemptId) } diff --git a/core/src/hadoop2/scala/spark/deploy/SparkHadoopUtil.scala b/core/src/hadoop2/scala/spark/deploy/SparkHadoopUtil.scala new file mode 100644 index 0000000000..a0fb4fe25d --- /dev/null +++ b/core/src/hadoop2/scala/spark/deploy/SparkHadoopUtil.scala @@ -0,0 +1,23 @@ +package spark.deploy +import org.apache.hadoop.conf.Configuration + + +/** + * Contains util methods to interact with Hadoop from spark. + */ +object SparkHadoopUtil { + + def getUserNameFromEnvironment(): String = { + // defaulting to -D ... + System.getProperty("user.name") + } + + def runAsUser(func: (Product) => Unit, args: Product) { + + // Add support, if exists - for now, simply run func ! + func(args) + } + + // Return an appropriate (subclass) of Configuration. 
Creating config can initializes some hadoop subsystems + def newConfiguration(): Configuration = new Configuration() +} diff --git a/core/src/main/java/spark/network/netty/FileClient.java b/core/src/main/java/spark/network/netty/FileClient.java new file mode 100644 index 0000000000..a4bb4bc701 --- /dev/null +++ b/core/src/main/java/spark/network/netty/FileClient.java @@ -0,0 +1,72 @@ +package spark.network.netty; + +import io.netty.bootstrap.Bootstrap; +import io.netty.channel.Channel; +import io.netty.channel.ChannelFuture; +import io.netty.channel.ChannelFutureListener; +import io.netty.channel.ChannelOption; +import io.netty.channel.oio.OioEventLoopGroup; +import io.netty.channel.socket.oio.OioSocketChannel; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +class FileClient { + + private Logger LOG = LoggerFactory.getLogger(this.getClass().getName()); + private FileClientHandler handler = null; + private Channel channel = null; + private Bootstrap bootstrap = null; + private int connectTimeout = 60*1000; // 1 min + + public FileClient(FileClientHandler handler, int connectTimeout) { + this.handler = handler; + this.connectTimeout = connectTimeout; + } + + public void init() { + bootstrap = new Bootstrap(); + bootstrap.group(new OioEventLoopGroup()) + .channel(OioSocketChannel.class) + .option(ChannelOption.SO_KEEPALIVE, true) + .option(ChannelOption.TCP_NODELAY, true) + .option(ChannelOption.CONNECT_TIMEOUT_MILLIS, connectTimeout) + .handler(new FileClientChannelInitializer(handler)); + } + + public void connect(String host, int port) { + try { + // Start the connection attempt. 
+ channel = bootstrap.connect(host, port).sync().channel(); + // ChannelFuture cf = channel.closeFuture(); + //cf.addListener(new ChannelCloseListener(this)); + } catch (InterruptedException e) { + close(); + } + } + + public void waitForClose() { + try { + channel.closeFuture().sync(); + } catch (InterruptedException e) { + LOG.warn("FileClient interrupted", e); + } + } + + public void sendRequest(String file) { + //assert(file == null); + //assert(channel == null); + channel.write(file + "\r\n"); + } + + public void close() { + if(channel != null) { + channel.close(); + channel = null; + } + if ( bootstrap!=null) { + bootstrap.shutdown(); + bootstrap = null; + } + } +} diff --git a/core/src/main/java/spark/network/netty/FileClientChannelInitializer.java b/core/src/main/java/spark/network/netty/FileClientChannelInitializer.java new file mode 100644 index 0000000000..af25baf641 --- /dev/null +++ b/core/src/main/java/spark/network/netty/FileClientChannelInitializer.java @@ -0,0 +1,24 @@ +package spark.network.netty; + +import io.netty.buffer.BufType; +import io.netty.channel.ChannelInitializer; +import io.netty.channel.socket.SocketChannel; +import io.netty.handler.codec.string.StringEncoder; + + +class FileClientChannelInitializer extends ChannelInitializer<SocketChannel> { + + private FileClientHandler fhandler; + + public FileClientChannelInitializer(FileClientHandler handler) { + fhandler = handler; + } + + @Override + public void initChannel(SocketChannel channel) { + // file no more than 2G + channel.pipeline() + .addLast("encoder", new StringEncoder(BufType.BYTE)) + .addLast("handler", fhandler); + } +} diff --git a/core/src/main/java/spark/network/netty/FileClientHandler.java b/core/src/main/java/spark/network/netty/FileClientHandler.java new file mode 100644 index 0000000000..9fc9449827 --- /dev/null +++ b/core/src/main/java/spark/network/netty/FileClientHandler.java @@ -0,0 +1,43 @@ +package spark.network.netty; + +import io.netty.buffer.ByteBuf; +import 
io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelInboundByteHandlerAdapter; + + +abstract class FileClientHandler extends ChannelInboundByteHandlerAdapter { + + private FileHeader currentHeader = null; + + private volatile boolean handlerCalled = false; + + public boolean isComplete() { + return handlerCalled; + } + + public abstract void handle(ChannelHandlerContext ctx, ByteBuf in, FileHeader header); + public abstract void handleError(String blockId); + + @Override + public ByteBuf newInboundBuffer(ChannelHandlerContext ctx) { + // Use direct buffer if possible. + return ctx.alloc().ioBuffer(); + } + + @Override + public void inboundBufferUpdated(ChannelHandlerContext ctx, ByteBuf in) { + // get header + if (currentHeader == null && in.readableBytes() >= FileHeader.HEADER_SIZE()) { + currentHeader = FileHeader.create(in.readBytes(FileHeader.HEADER_SIZE())); + } + // get file + if(in.readableBytes() >= currentHeader.fileLen()) { + handle(ctx, in, currentHeader); + handlerCalled = true; + currentHeader = null; + ctx.close(); + } + } + +} + diff --git a/core/src/main/java/spark/network/netty/FileServer.java b/core/src/main/java/spark/network/netty/FileServer.java new file mode 100644 index 0000000000..dd3a557ae5 --- /dev/null +++ b/core/src/main/java/spark/network/netty/FileServer.java @@ -0,0 +1,86 @@ +package spark.network.netty; + +import java.net.InetSocketAddress; + +import io.netty.bootstrap.ServerBootstrap; +import io.netty.channel.Channel; +import io.netty.channel.ChannelFuture; +import io.netty.channel.ChannelOption; +import io.netty.channel.oio.OioEventLoopGroup; +import io.netty.channel.socket.oio.OioServerSocketChannel; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + + +/** + * Server that accept the path of a file an echo back its content. 
+ */ +class FileServer { + + private Logger LOG = LoggerFactory.getLogger(this.getClass().getName()); + + private ServerBootstrap bootstrap = null; + private ChannelFuture channelFuture = null; + private int port = 0; + private Thread blockingThread = null; + + public FileServer(PathResolver pResolver, int port) { + InetSocketAddress addr = new InetSocketAddress(port); + + // Configure the server. + bootstrap = new ServerBootstrap(); + bootstrap.group(new OioEventLoopGroup(), new OioEventLoopGroup()) + .channel(OioServerSocketChannel.class) + .option(ChannelOption.SO_BACKLOG, 100) + .option(ChannelOption.SO_RCVBUF, 1500) + .childHandler(new FileServerChannelInitializer(pResolver)); + // Start the server. + channelFuture = bootstrap.bind(addr); + try { + // Get the address we bound to. + InetSocketAddress boundAddress = + ((InetSocketAddress) channelFuture.sync().channel().localAddress()); + this.port = boundAddress.getPort(); + } catch (InterruptedException ie) { + this.port = 0; + } + } + + /** + * Start the file server asynchronously in a new thread. + */ + public void start() { + blockingThread = new Thread() { + public void run() { + try { + channelFuture.channel().closeFuture().sync(); + LOG.info("FileServer exiting"); + } catch (InterruptedException e) { + LOG.error("File server start got interrupted", e); + } + // NOTE: bootstrap is shutdown in stop() + } + }; + blockingThread.setDaemon(true); + blockingThread.start(); + } + + public int getPort() { + return port; + } + + public void stop() { + // Close the bound channel. + if (channelFuture != null) { + channelFuture.channel().close(); + channelFuture = null; + } + // Shutdown bootstrap. + if (bootstrap != null) { + bootstrap.shutdown(); + bootstrap = null; + } + // TODO: Shutdown all accepted channels as well ? 
+ } +} diff --git a/core/src/main/java/spark/network/netty/FileServerChannelInitializer.java b/core/src/main/java/spark/network/netty/FileServerChannelInitializer.java new file mode 100644 index 0000000000..8f1f5c65cd --- /dev/null +++ b/core/src/main/java/spark/network/netty/FileServerChannelInitializer.java @@ -0,0 +1,25 @@ +package spark.network.netty; + +import io.netty.channel.ChannelInitializer; +import io.netty.channel.socket.SocketChannel; +import io.netty.handler.codec.DelimiterBasedFrameDecoder; +import io.netty.handler.codec.Delimiters; +import io.netty.handler.codec.string.StringDecoder; + + +class FileServerChannelInitializer extends ChannelInitializer<SocketChannel> { + + PathResolver pResolver; + + public FileServerChannelInitializer(PathResolver pResolver) { + this.pResolver = pResolver; + } + + @Override + public void initChannel(SocketChannel channel) { + channel.pipeline() + .addLast("framer", new DelimiterBasedFrameDecoder(8192, Delimiters.lineDelimiter())) + .addLast("strDecoder", new StringDecoder()) + .addLast("handler", new FileServerHandler(pResolver)); + } +} diff --git a/core/src/main/java/spark/network/netty/FileServerHandler.java b/core/src/main/java/spark/network/netty/FileServerHandler.java new file mode 100644 index 0000000000..a78eddb1b5 --- /dev/null +++ b/core/src/main/java/spark/network/netty/FileServerHandler.java @@ -0,0 +1,65 @@ +package spark.network.netty; + +import java.io.File; +import java.io.FileInputStream; + +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelInboundMessageHandlerAdapter; +import io.netty.channel.DefaultFileRegion; + + +class FileServerHandler extends ChannelInboundMessageHandlerAdapter<String> { + + PathResolver pResolver; + + public FileServerHandler(PathResolver pResolver){ + this.pResolver = pResolver; + } + + @Override + public void messageReceived(ChannelHandlerContext ctx, String blockId) { + String path = pResolver.getAbsolutePath(blockId); + // if getFilePath 
returns null, close the channel + if (path == null) { + //ctx.close(); + return; + } + File file = new File(path); + if (file.exists()) { + if (!file.isFile()) { + //logger.info("Not a file : " + file.getAbsolutePath()); + ctx.write(new FileHeader(0, blockId).buffer()); + ctx.flush(); + return; + } + long length = file.length(); + if (length > Integer.MAX_VALUE || length <= 0) { + //logger.info("too large file : " + file.getAbsolutePath() + " of size "+ length); + ctx.write(new FileHeader(0, blockId).buffer()); + ctx.flush(); + return; + } + int len = new Long(length).intValue(); + //logger.info("Sending block "+blockId+" filelen = "+len); + //logger.info("header = "+ (new FileHeader(len, blockId)).buffer()); + ctx.write((new FileHeader(len, blockId)).buffer()); + try { + ctx.sendFile(new DefaultFileRegion(new FileInputStream(file) + .getChannel(), 0, file.length())); + } catch (Exception e) { + //logger.warning("Exception when sending file : " + file.getAbsolutePath()); + e.printStackTrace(); + } + } else { + //logger.warning("File not found: " + file.getAbsolutePath()); + ctx.write(new FileHeader(0, blockId).buffer()); + } + ctx.flush(); + } + + @Override + public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) { + cause.printStackTrace(); + ctx.close(); + } +} diff --git a/core/src/main/java/spark/network/netty/PathResolver.java b/core/src/main/java/spark/network/netty/PathResolver.java new file mode 100755 index 0000000000..302411672c --- /dev/null +++ b/core/src/main/java/spark/network/netty/PathResolver.java @@ -0,0 +1,12 @@ +package spark.network.netty;
+
+
+public interface PathResolver {
+ /**
+ * Get the absolute path of the file identified by {@code fileId}.
+ *
+ * @param fileId identifier of the file to resolve; FileServerHandler passes the block id it received from the client
+ * @return the absolute path of the file; FileServerHandler checks the result for null before using it
+ */
+ public String getAbsolutePath(String fileId);
+}
diff --git a/core/src/main/scala/spark/BlockStoreShuffleFetcher.scala b/core/src/main/scala/spark/BlockStoreShuffleFetcher.scala index c27ed36406..3239f4c385 100644 --- a/core/src/main/scala/spark/BlockStoreShuffleFetcher.scala +++ b/core/src/main/scala/spark/BlockStoreShuffleFetcher.scala @@ -1,14 +1,19 @@ package spark -import executor.{ShuffleReadMetrics, TaskMetrics} import scala.collection.mutable.ArrayBuffer import scala.collection.mutable.HashMap -import spark.storage.{DelegateBlockFetchTracker, BlockManagerId} -import util.{CompletionIterator, TimedIterator} +import spark.executor.{ShuffleReadMetrics, TaskMetrics} +import spark.serializer.Serializer +import spark.storage.BlockManagerId +import spark.util.CompletionIterator + private[spark] class BlockStoreShuffleFetcher extends ShuffleFetcher with Logging { - override def fetch[K, V](shuffleId: Int, reduceId: Int, metrics: TaskMetrics) = { + + override def fetch[K, V]( + shuffleId: Int, reduceId: Int, metrics: TaskMetrics, serializer: Serializer) = { + logDebug("Fetching outputs for shuffle %d, reduce %d".format(shuffleId, reduceId)) val blockManager = SparkEnv.get.blockManager @@ -48,18 +53,18 @@ private[spark] class BlockStoreShuffleFetcher extends ShuffleFetcher with Loggin } } - val blockFetcherItr = blockManager.getMultiple(blocksByAddress) - val itr = new TimedIterator(blockFetcherItr.flatMap(unpackBlock)) with DelegateBlockFetchTracker - itr.setDelegate(blockFetcherItr) + val blockFetcherItr = blockManager.getMultiple(blocksByAddress, serializer) + val itr = blockFetcherItr.flatMap(unpackBlock) + CompletionIterator[(K,V), Iterator[(K,V)]](itr, { val shuffleMetrics = new ShuffleReadMetrics - shuffleMetrics.shuffleReadMillis = itr.getNetMillis - shuffleMetrics.remoteFetchTime = itr.remoteFetchTime - shuffleMetrics.fetchWaitTime = itr.fetchWaitTime - shuffleMetrics.remoteBytesRead = itr.remoteBytesRead - shuffleMetrics.totalBlocksFetched = itr.totalBlocks - shuffleMetrics.localBlocksFetched = 
itr.numLocalBlocks - shuffleMetrics.remoteBlocksFetched = itr.numRemoteBlocks + shuffleMetrics.shuffleFinishTime = System.currentTimeMillis + shuffleMetrics.remoteFetchTime = blockFetcherItr.remoteFetchTime + shuffleMetrics.fetchWaitTime = blockFetcherItr.fetchWaitTime + shuffleMetrics.remoteBytesRead = blockFetcherItr.remoteBytesRead + shuffleMetrics.totalBlocksFetched = blockFetcherItr.totalBlocks + shuffleMetrics.localBlocksFetched = blockFetcherItr.numLocalBlocks + shuffleMetrics.remoteBlocksFetched = blockFetcherItr.numRemoteBlocks metrics.shuffleReadMetrics = Some(shuffleMetrics) }) } diff --git a/core/src/main/scala/spark/ClosureCleaner.scala b/core/src/main/scala/spark/ClosureCleaner.scala index 98525b99c8..d5e7132ff9 100644 --- a/core/src/main/scala/spark/ClosureCleaner.scala +++ b/core/src/main/scala/spark/ClosureCleaner.scala @@ -5,15 +5,22 @@ import java.lang.reflect.Field import scala.collection.mutable.Map import scala.collection.mutable.Set -import org.objectweb.asm.{ClassReader, MethodVisitor, Type} -import org.objectweb.asm.commons.EmptyVisitor +import org.objectweb.asm.{ClassReader, ClassVisitor, MethodVisitor, Type} import org.objectweb.asm.Opcodes._ +import java.io.{InputStream, IOException, ByteArrayOutputStream, ByteArrayInputStream, BufferedInputStream} private[spark] object ClosureCleaner extends Logging { // Get an ASM class reader for a given class from the JAR that loaded it private def getClassReader(cls: Class[_]): ClassReader = { - new ClassReader(cls.getResourceAsStream( - cls.getName.replaceFirst("^.*\\.", "") + ".class")) + // Copy data over, before delegating to ClassReader - else we can run out of open file handles. + val className = cls.getName.replaceFirst("^.*\\.", "") + ".class" + val resourceStream = cls.getResourceAsStream(className) + // todo: Fixme - continuing with earlier behavior ... 
+ if (resourceStream == null) return new ClassReader(resourceStream) + + val baos = new ByteArrayOutputStream(128) + Utils.copyStream(resourceStream, baos, true) + new ClassReader(new ByteArrayInputStream(baos.toByteArray)) } // Check whether a class represents a Scala closure @@ -154,10 +161,10 @@ private[spark] object ClosureCleaner extends Logging { } } -private[spark] class FieldAccessFinder(output: Map[Class[_], Set[String]]) extends EmptyVisitor { +private[spark] class FieldAccessFinder(output: Map[Class[_], Set[String]]) extends ClassVisitor(ASM4) { override def visitMethod(access: Int, name: String, desc: String, sig: String, exceptions: Array[String]): MethodVisitor = { - return new EmptyVisitor { + return new MethodVisitor(ASM4) { override def visitFieldInsn(op: Int, owner: String, name: String, desc: String) { if (op == GETFIELD) { for (cl <- output.keys if cl.getName == owner.replace('/', '.')) { @@ -180,7 +187,7 @@ private[spark] class FieldAccessFinder(output: Map[Class[_], Set[String]]) exten } } -private[spark] class InnerClosureFinder(output: Set[Class[_]]) extends EmptyVisitor { +private[spark] class InnerClosureFinder(output: Set[Class[_]]) extends ClassVisitor(ASM4) { var myName: String = null override def visit(version: Int, access: Int, name: String, sig: String, @@ -190,7 +197,7 @@ private[spark] class InnerClosureFinder(output: Set[Class[_]]) extends EmptyVisi override def visitMethod(access: Int, name: String, desc: String, sig: String, exceptions: Array[String]): MethodVisitor = { - return new EmptyVisitor { + return new MethodVisitor(ASM4) { override def visitMethodInsn(op: Int, owner: String, name: String, desc: String) { val argTypes = Type.getArgumentTypes(desc) diff --git a/core/src/main/scala/spark/Dependency.scala b/core/src/main/scala/spark/Dependency.scala index 5eea907322..2af44aa383 100644 --- a/core/src/main/scala/spark/Dependency.scala +++ b/core/src/main/scala/spark/Dependency.scala @@ -25,10 +25,12 @@ abstract class 
NarrowDependency[T](rdd: RDD[T]) extends Dependency(rdd) { * @param shuffleId the shuffle id * @param rdd the parent RDD * @param partitioner partitioner used to partition the shuffle output + * @param serializerClass class name of the serializer to use */ class ShuffleDependency[K, V]( @transient rdd: RDD[(K, V)], - val partitioner: Partitioner) + val partitioner: Partitioner, + val serializerClass: String = null) extends Dependency(rdd) { val shuffleId: Int = rdd.context.newShuffleId() diff --git a/core/src/main/scala/spark/FetchFailedException.scala b/core/src/main/scala/spark/FetchFailedException.scala index a953081d24..40b0193f19 100644 --- a/core/src/main/scala/spark/FetchFailedException.scala +++ b/core/src/main/scala/spark/FetchFailedException.scala @@ -3,18 +3,25 @@ package spark import spark.storage.BlockManagerId private[spark] class FetchFailedException( - val bmAddress: BlockManagerId, - val shuffleId: Int, - val mapId: Int, - val reduceId: Int, + taskEndReason: TaskEndReason, + message: String, cause: Throwable) extends Exception { - - override def getMessage(): String = - "Fetch failed: %s %d %d %d".format(bmAddress, shuffleId, mapId, reduceId) + + def this (bmAddress: BlockManagerId, shuffleId: Int, mapId: Int, reduceId: Int, cause: Throwable) = + this(FetchFailed(bmAddress, shuffleId, mapId, reduceId), + "Fetch failed: %s %d %d %d".format(bmAddress, shuffleId, mapId, reduceId), + cause) + + def this (shuffleId: Int, reduceId: Int, cause: Throwable) = + this(FetchFailed(null, shuffleId, -1, reduceId), + "Unable to fetch locations from master: %d %d".format(shuffleId, reduceId), cause) + + override def getMessage(): String = message + override def getCause(): Throwable = cause - def toTaskEndReason: TaskEndReason = - FetchFailed(bmAddress, shuffleId, mapId, reduceId) + def toTaskEndReason: TaskEndReason = taskEndReason + } diff --git a/core/src/main/scala/spark/HadoopWriter.scala b/core/src/main/scala/spark/HadoopWriter.scala index 
afcf9f6db4..5e8396edb9 100644 --- a/core/src/main/scala/spark/HadoopWriter.scala +++ b/core/src/main/scala/spark/HadoopWriter.scala @@ -2,14 +2,10 @@ package org.apache.hadoop.mapred import org.apache.hadoop.fs.FileSystem import org.apache.hadoop.fs.Path -import org.apache.hadoop.util.ReflectionUtils -import org.apache.hadoop.io.NullWritable -import org.apache.hadoop.io.Text import java.text.SimpleDateFormat import java.text.NumberFormat import java.io.IOException -import java.net.URI import java.util.Date import spark.Logging @@ -24,7 +20,7 @@ import spark.SerializableWritable * a filename to write to, etc, exactly like in a Hadoop MapReduce job. */ class HadoopWriter(@transient jobConf: JobConf) extends Logging with HadoopMapRedUtil with Serializable { - + private val now = new Date() private val conf = new SerializableWritable(jobConf) @@ -106,6 +102,12 @@ class HadoopWriter(@transient jobConf: JobConf) extends Logging with HadoopMapRe } } + def commitJob() { + // always ? Or if cmtr.needsTaskCommit ? + val cmtr = getOutputCommitter() + cmtr.commitJob(getJobContext()) + } + def cleanup() { getOutputCommitter().cleanupJob(getJobContext()) } diff --git a/core/src/main/scala/spark/Logging.scala b/core/src/main/scala/spark/Logging.scala index 7c1c1bb144..0fc8c31463 100644 --- a/core/src/main/scala/spark/Logging.scala +++ b/core/src/main/scala/spark/Logging.scala @@ -68,6 +68,10 @@ trait Logging { if (log.isErrorEnabled) log.error(msg, throwable) } + protected def isTraceEnabled(): Boolean = { + log.isTraceEnabled + } + // Method for ensuring that logging is initialized, to avoid having multiple // threads do it concurrently (as SLF4J initialization is not thread safe). 
protected def initLogging() { log } diff --git a/core/src/main/scala/spark/MapOutputTracker.scala b/core/src/main/scala/spark/MapOutputTracker.scala index 50708d9cb1..0fc6427307 100644 --- a/core/src/main/scala/spark/MapOutputTracker.scala +++ b/core/src/main/scala/spark/MapOutputTracker.scala @@ -1,7 +1,6 @@ package spark import java.io._ -import java.util.concurrent.ConcurrentHashMap import java.util.zip.{GZIPInputStream, GZIPOutputStream} import scala.collection.mutable.HashMap @@ -11,6 +10,7 @@ import akka.actor._ import scala.concurrent.Await import akka.pattern.ask import akka.remote._ + import scala.concurrent.duration.Duration import akka.util.Timeout import scala.concurrent.duration._ @@ -40,10 +40,12 @@ private[spark] class MapOutputTrackerActor(tracker: MapOutputTracker) extends Ac private[spark] class MapOutputTracker extends Logging { + private val timeout = Duration.create(System.getProperty("spark.akka.askTimeout", "10").toLong, "seconds") + // Set to the MapOutputTrackerActor living on the driver var trackerActor: ActorRef = _ - var mapStatuses = new TimeStampedHashMap[Int, Array[MapStatus]] + private var mapStatuses = new TimeStampedHashMap[Int, Array[MapStatus]] // Incremented every time a fetch fails so that client nodes know to clear // their cache of map output locations if this happens. @@ -52,7 +54,7 @@ private[spark] class MapOutputTracker extends Logging { // Cache a serialized version of the output statuses for each shuffle to send them out faster var cacheGeneration = generation - val cachedSerializedStatuses = new TimeStampedHashMap[Int, Array[Byte]] + private val cachedSerializedStatuses = new TimeStampedHashMap[Int, Array[Byte]] val metadataCleaner = new MetadataCleaner("MapOutputTracker", this.cleanup) @@ -60,7 +62,6 @@ private[spark] class MapOutputTracker extends Logging { // throw a SparkException if this fails. 
def askTracker(message: Any): Any = { try { - val timeout = 10.seconds val future = trackerActor.ask(message)(timeout) return Await.result(future, timeout) } catch { @@ -77,10 +78,9 @@ private[spark] class MapOutputTracker extends Logging { } def registerShuffle(shuffleId: Int, numMaps: Int) { - if (mapStatuses.get(shuffleId) != None) { + if (mapStatuses.putIfAbsent(shuffleId, new Array[MapStatus](numMaps)).isDefined) { throw new IllegalArgumentException("Shuffle ID " + shuffleId + " registered twice") } - mapStatuses.put(shuffleId, new Array[MapStatus](numMaps)) } def registerMapOutput(shuffleId: Int, mapId: Int, status: MapStatus) { @@ -101,8 +101,9 @@ private[spark] class MapOutputTracker extends Logging { } def unregisterMapOutput(shuffleId: Int, mapId: Int, bmAddress: BlockManagerId) { - var array = mapStatuses(shuffleId) - if (array != null) { + var arrayOpt = mapStatuses.get(shuffleId) + if (arrayOpt.isDefined && arrayOpt.get != null) { + var array = arrayOpt.get array.synchronized { if (array(mapId) != null && array(mapId).location == bmAddress) { array(mapId) = null @@ -115,13 +116,14 @@ private[spark] class MapOutputTracker extends Logging { } // Remembers which map output locations are currently being fetched on a worker - val fetching = new HashSet[Int] + private val fetching = new HashSet[Int] // Called on possibly remote nodes to get the server URIs and output sizes for a given shuffle def getServerStatuses(shuffleId: Int, reduceId: Int): Array[(BlockManagerId, Long)] = { val statuses = mapStatuses.get(shuffleId).orNull if (statuses == null) { logInfo("Don't have map outputs for shuffle " + shuffleId + ", fetching them") + var fetchedStatuses: Array[MapStatus] = null fetching.synchronized { if (fetching.contains(shuffleId)) { // Someone else is fetching it; wait for them to be done @@ -132,31 +134,48 @@ private[spark] class MapOutputTracker extends Logging { case e: InterruptedException => } } - return MapOutputTracker.convertMapStatuses(shuffleId, 
reduceId, mapStatuses(shuffleId)) - } else { + } + + // Either while we waited the fetch happened successfully, or + // someone fetched it in between the get and the fetching.synchronized. + fetchedStatuses = mapStatuses.get(shuffleId).orNull + if (fetchedStatuses == null) { + // We have to do the fetch, get others to wait for us. fetching += shuffleId } } - // We won the race to fetch the output locs; do so - logInfo("Doing the fetch; tracker actor = " + trackerActor) - val host = System.getProperty("spark.hostname", Utils.localHostName) - // This try-finally prevents hangs due to timeouts: - var fetchedStatuses: Array[MapStatus] = null - try { - val fetchedBytes = - askTracker(GetMapOutputStatuses(shuffleId, host)).asInstanceOf[Array[Byte]] - fetchedStatuses = deserializeStatuses(fetchedBytes) - logInfo("Got the output locations") - mapStatuses.put(shuffleId, fetchedStatuses) - } finally { - fetching.synchronized { - fetching -= shuffleId - fetching.notifyAll() + + if (fetchedStatuses == null) { + // We won the race to fetch the output locs; do so + logInfo("Doing the fetch; tracker actor = " + trackerActor) + val hostPort = Utils.localHostPort() + // This try-finally prevents hangs due to timeouts: + try { + val fetchedBytes = + askTracker(GetMapOutputStatuses(shuffleId, hostPort)).asInstanceOf[Array[Byte]] + fetchedStatuses = deserializeStatuses(fetchedBytes) + logInfo("Got the output locations") + mapStatuses.put(shuffleId, fetchedStatuses) + } finally { + fetching.synchronized { + fetching -= shuffleId + fetching.notifyAll() + } + } + } + if (fetchedStatuses != null) { + fetchedStatuses.synchronized { + return MapOutputTracker.convertMapStatuses(shuffleId, reduceId, fetchedStatuses) } } - return MapOutputTracker.convertMapStatuses(shuffleId, reduceId, fetchedStatuses) + else{ + throw new FetchFailedException(null, shuffleId, -1, reduceId, + new Exception("Missing all output locations for shuffle " + shuffleId)) + } } else { - return 
MapOutputTracker.convertMapStatuses(shuffleId, reduceId, statuses) + statuses.synchronized { + return MapOutputTracker.convertMapStatuses(shuffleId, reduceId, statuses) + } } } @@ -194,7 +213,8 @@ private[spark] class MapOutputTracker extends Logging { generationLock.synchronized { if (newGen > generation) { logInfo("Updating generation to " + newGen + " and clearing cache") - mapStatuses = new TimeStampedHashMap[Int, Array[MapStatus]] + // mapStatuses = new TimeStampedHashMap[Int, Array[MapStatus]] + mapStatuses.clear() generation = newGen } } @@ -232,10 +252,13 @@ private[spark] class MapOutputTracker extends Logging { // Serialize an array of map output locations into an efficient byte format so that we can send // it to reduce tasks. We do this by compressing the serialized bytes using GZIP. They will // generally be pretty compressible because many map outputs will be on the same hostname. - def serializeStatuses(statuses: Array[MapStatus]): Array[Byte] = { + private def serializeStatuses(statuses: Array[MapStatus]): Array[Byte] = { val out = new ByteArrayOutputStream val objOut = new ObjectOutputStream(new GZIPOutputStream(out)) - objOut.writeObject(statuses) + // Since statuses can be modified in parallel, sync on it + statuses.synchronized { + objOut.writeObject(statuses) + } objOut.close() out.toByteArray } @@ -243,7 +266,10 @@ private[spark] class MapOutputTracker extends Logging { // Opposite of serializeStatuses. def deserializeStatuses(bytes: Array[Byte]): Array[MapStatus] = { val objIn = new ObjectInputStream(new GZIPInputStream(new ByteArrayInputStream(bytes))) - objIn.readObject().asInstanceOf[Array[MapStatus]] + objIn.readObject(). + // // drop all nulls from status - not sure why they are occurring though. Causes NPE downstream in slave if present + // comment this out - nulls could be due to missing location ? 
+ asInstanceOf[Array[MapStatus]] // .filter( _ != null ) } } @@ -253,16 +279,13 @@ private[spark] object MapOutputTracker { // Convert an array of MapStatuses to locations and sizes for a given reduce ID. If // any of the statuses is null (indicating a missing location due to a failed mapper), // throw a FetchFailedException. - def convertMapStatuses( + private def convertMapStatuses( shuffleId: Int, reduceId: Int, statuses: Array[MapStatus]): Array[(BlockManagerId, Long)] = { - if (statuses == null) { - throw new FetchFailedException(null, shuffleId, -1, reduceId, - new Exception("Missing all output locations for shuffle " + shuffleId)) - } + assert (statuses != null) statuses.map { - status => + status => if (status == null) { throw new FetchFailedException(null, shuffleId, -1, reduceId, new Exception("Missing an output location for shuffle " + shuffleId)) diff --git a/core/src/main/scala/spark/PairRDDFunctions.scala b/core/src/main/scala/spark/PairRDDFunctions.scala index 2052d05788..fe812fe530 100644 --- a/core/src/main/scala/spark/PairRDDFunctions.scala +++ b/core/src/main/scala/spark/PairRDDFunctions.scala @@ -1,5 +1,6 @@ package spark +import java.nio.ByteBuffer import java.util.{Date, HashMap => JHashMap} import java.text.SimpleDateFormat @@ -11,6 +12,8 @@ import scala.reflect.{ ClassTag, classTag} import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path +import org.apache.hadoop.io.compress.CompressionCodec +import org.apache.hadoop.io.SequenceFile.CompressionType import org.apache.hadoop.mapred.FileOutputCommitter import org.apache.hadoop.mapred.FileOutputFormat import org.apache.hadoop.mapred.HadoopWriter @@ -18,7 +21,7 @@ import org.apache.hadoop.mapred.JobConf import org.apache.hadoop.mapred.OutputFormat import org.apache.hadoop.mapreduce.lib.output.{FileOutputFormat => NewFileOutputFormat} -import org.apache.hadoop.mapreduce.{OutputFormat => NewOutputFormat, RecordWriter => NewRecordWriter, Job => NewAPIHadoopJob, 
HadoopMapReduceUtil, TaskAttemptID, TaskAttemptContext} +import org.apache.hadoop.mapreduce.{OutputFormat => NewOutputFormat, RecordWriter => NewRecordWriter, Job => NewAPIHadoopJob, HadoopMapReduceUtil} import spark.partial.BoundedDouble import spark.partial.PartialResult @@ -53,7 +56,8 @@ class PairRDDFunctions[K: ClassTag, V: ClassTag]( mergeValue: (C, V) => C, mergeCombiners: (C, C) => C, partitioner: Partitioner, - mapSideCombine: Boolean = true): RDD[(K, C)] = { + mapSideCombine: Boolean = true, + serializerClass: String = null): RDD[(K, C)] = { if (getKeyClass().isArray) { if (mapSideCombine) { throw new SparkException("Cannot use map-side combining with array keys.") @@ -62,19 +66,18 @@ class PairRDDFunctions[K: ClassTag, V: ClassTag]( throw new SparkException("Default partitioner cannot partition array keys.") } } - val aggregator = - new Aggregator[K, V, C](createCombiner, mergeValue, mergeCombiners) + val aggregator = new Aggregator[K, V, C](createCombiner, mergeValue, mergeCombiners) if (self.partitioner == Some(partitioner)) { self.mapPartitions(aggregator.combineValuesByKey(_), true) } else if (mapSideCombine) { val mapSideCombined = self.mapPartitions(aggregator.combineValuesByKey(_), true) - val partitioned = new ShuffledRDD[K, C](mapSideCombined, partitioner) + val partitioned = new ShuffledRDD[K, C](mapSideCombined, partitioner, serializerClass) partitioned.mapPartitions(aggregator.combineCombinersByKey(_), true) } else { // Don't apply map-side combiner. // A sanity check to make sure mergeCombiners is not defined. assert(mergeCombiners == null) - val values = new ShuffledRDD[K, V](self, partitioner) + val values = new ShuffledRDD[K, V](self, partitioner, serializerClass) values.mapPartitions(aggregator.combineValuesByKey(_), true) } } @@ -95,7 +98,16 @@ class PairRDDFunctions[K: ClassTag, V: ClassTag]( * list concatenation, 0 for addition, or 1 for multiplication.). 
*/ def foldByKey(zeroValue: V, partitioner: Partitioner)(func: (V, V) => V): RDD[(K, V)] = { - combineByKey[V]({v: V => func(zeroValue, v)}, func, func, partitioner) + // Serialize the zero value to a byte array so that we can get a new clone of it on each key + val zeroBuffer = SparkEnv.get.closureSerializer.newInstance().serialize(zeroValue) + val zeroArray = new Array[Byte](zeroBuffer.limit) + zeroBuffer.get(zeroArray) + + // When deserializing, use a lazy val to create just one instance of the serializer per task + lazy val cachedSerializer = SparkEnv.get.closureSerializer.newInstance() + def createZero() = cachedSerializer.deserialize[V](ByteBuffer.wrap(zeroArray)) + + combineByKey[V]((v: V) => func(createZero(), v), func, func, partitioner) } /** @@ -185,11 +197,13 @@ class PairRDDFunctions[K: ClassTag, V: ClassTag]( * partitioning of the resulting key-value pair RDD by passing a Partitioner. */ def groupByKey(partitioner: Partitioner): RDD[(K, Seq[V])] = { + // groupByKey shouldn't use map side combine because map side combine does not + // reduce the amount of data shuffled and requires all map side data be inserted + // into a hash table, leading to more objects in the old gen. def createCombiner(v: V) = ArrayBuffer(v) def mergeValue(buf: ArrayBuffer[V], v: V) = buf += v - def mergeCombiners(b1: ArrayBuffer[V], b2: ArrayBuffer[V]) = b1 ++= b2 val bufs = combineByKey[ArrayBuffer[V]]( - createCombiner _, mergeValue _, mergeCombiners _, partitioner) + createCombiner _, mergeValue _, null, partitioner, mapSideCombine=false) bufs.asInstanceOf[RDD[(K, Seq[V])]] } @@ -516,6 +530,16 @@ class PairRDDFunctions[K: ClassTag, V: ClassTag]( } /** + * Output the RDD to any Hadoop-supported file system, using a Hadoop `OutputFormat` class + * supporting the key and value types K and V in this RDD. Compress the result with the + * supplied codec. 
+ */ + def saveAsHadoopFile[F <: OutputFormat[K, V]]( + path: String, codec: Class[_ <: CompressionCodec]) (implicit fm: ClassManifest[F]) { + saveAsHadoopFile(path, getKeyClass, getValueClass, fm.erasure.asInstanceOf[Class[F]], codec) + } + + /** * Output the RDD to any Hadoop-supported file system, using a new Hadoop API `OutputFormat` * (mapreduce.OutputFormat) object supporting the key and value types K and V in this RDD. */ @@ -546,8 +570,7 @@ class PairRDDFunctions[K: ClassTag, V: ClassTag]( // around by taking a mod. We expect that no task will be attempted 2 billion times. val attemptNumber = (context.attemptId % Int.MaxValue).toInt /* "reduce task" <split #> <attempt # = spark task #> */ - val attemptId = new TaskAttemptID(jobtrackerID, - stageId, false, context.splitId, attemptNumber) + val attemptId = newTaskAttemptID(jobtrackerID, stageId, false, context.splitId, attemptNumber) val hadoopContext = newTaskAttemptContext(wrappedConf.value, attemptId) val format = outputFormatClass.newInstance val committer = format.getOutputCommitter(hadoopContext) @@ -566,16 +589,31 @@ class PairRDDFunctions[K: ClassTag, V: ClassTag]( * however we're only going to use this local OutputCommitter for * setupJob/commitJob, so we just use a dummy "map" task. */ - val jobAttemptId = new TaskAttemptID(jobtrackerID, stageId, true, 0, 0) + val jobAttemptId = newTaskAttemptID(jobtrackerID, stageId, true, 0, 0) val jobTaskContext = newTaskAttemptContext(wrappedConf.value, jobAttemptId) val jobCommitter = jobFormat.getOutputCommitter(jobTaskContext) jobCommitter.setupJob(jobTaskContext) val count = self.context.runJob(self, writeShard _).sum + jobCommitter.commitJob(jobTaskContext) jobCommitter.cleanupJob(jobTaskContext) } /** * Output the RDD to any Hadoop-supported file system, using a Hadoop `OutputFormat` class + * supporting the key and value types K and V in this RDD. Compress with the supplied codec. 
+ */ + def saveAsHadoopFile( + path: String, + keyClass: Class[_], + valueClass: Class[_], + outputFormatClass: Class[_ <: OutputFormat[_, _]], + codec: Class[_ <: CompressionCodec]) { + saveAsHadoopFile(path, keyClass, valueClass, outputFormatClass, + new JobConf(self.context.hadoopConfiguration), Some(codec)) + } + + /** + * Output the RDD to any Hadoop-supported file system, using a Hadoop `OutputFormat` class * supporting the key and value types K and V in this RDD. */ def saveAsHadoopFile( @@ -583,11 +621,19 @@ class PairRDDFunctions[K: ClassTag, V: ClassTag]( keyClass: Class[_], valueClass: Class[_], outputFormatClass: Class[_ <: OutputFormat[_, _]], - conf: JobConf = new JobConf(self.context.hadoopConfiguration)) { + conf: JobConf = new JobConf(self.context.hadoopConfiguration), + codec: Option[Class[_ <: CompressionCodec]] = None) { conf.setOutputKeyClass(keyClass) conf.setOutputValueClass(valueClass) // conf.setOutputFormat(outputFormatClass) // Doesn't work in Scala 2.9 due to what may be a generics bug conf.set("mapred.output.format.class", outputFormatClass.getName) + for (c <- codec) { + conf.setCompressMapOutput(true) + conf.set("mapred.output.compress", "true") + conf.setMapOutputCompressorClass(c) + conf.set("mapred.output.compression.codec", c.getCanonicalName) + conf.set("mapred.output.compression.type", CompressionType.BLOCK.toString) + } conf.setOutputCommitter(classOf[FileOutputCommitter]) FileOutputFormat.setOutputPath(conf, HadoopWriter.createPathFromString(path, conf)) saveAsHadoopDataset(conf) @@ -638,6 +684,7 @@ class PairRDDFunctions[K: ClassTag, V: ClassTag]( } self.context.runJob(self, writeToFile _) + writer.commitJob() writer.cleanup() } diff --git a/core/src/main/scala/spark/RDD.scala b/core/src/main/scala/spark/RDD.scala index 6ee075315a..e88290fdb2 100644 --- a/core/src/main/scala/spark/RDD.scala +++ b/core/src/main/scala/spark/RDD.scala @@ -1,22 +1,23 @@ package spark -import java.net.URL -import java.util.{Date, Random} -import 
java.util.{HashMap => JHashMap} +import java.util.Random import scala.collection.Map import scala.collection.JavaConversions.mapAsScalaMap import scala.collection.mutable.ArrayBuffer + import scala.collection.mutable.HashMap import scala.reflect.{classTag, ClassTag} import org.apache.hadoop.io.BytesWritable +import org.apache.hadoop.io.compress.CompressionCodec import org.apache.hadoop.io.NullWritable import org.apache.hadoop.io.Text import org.apache.hadoop.mapred.TextOutputFormat import it.unimi.dsi.fastutil.objects.{Object2LongOpenHashMap => OLMap} +import spark.broadcast.Broadcast import spark.Partitioner._ import spark.partial.BoundedDouble import spark.partial.CountEvaluator @@ -33,10 +34,13 @@ import spark.rdd.MapPartitionsWithIndexRDD import spark.rdd.PipedRDD import spark.rdd.SampledRDD import spark.rdd.ShuffledRDD -import spark.rdd.SubtractedRDD import spark.rdd.UnionRDD import spark.rdd.ZippedRDD +import spark.rdd.ZippedPartitionsRDD2 +import spark.rdd.ZippedPartitionsRDD3 +import spark.rdd.ZippedPartitionsRDD4 import spark.storage.StorageLevel +import spark.util.BoundedPriorityQueue import SparkContext._ @@ -105,7 +109,7 @@ abstract class RDD[T: ClassTag]( // ======================================================================= /** A unique ID for this RDD (within its SparkContext). */ - val id = sc.newRddId() + val id: Int = sc.newRddId() /** A friendly name for this RDD */ var name: String = null @@ -116,9 +120,18 @@ abstract class RDD[T: ClassTag]( this } + /** User-defined generator of this RDD*/ + var generator = Utils.getCallSiteInfo.firstUserClass + + /** Reset generator*/ + def setGenerator(_generator: String) = { + generator = _generator + } + /** * Set this RDD's storage level to persist its values across operations after the first time - * it is computed. Can only be called once on each RDD. + * it is computed. This can only be used to assign a new storage level if the RDD does not + * have a storage level set yet. 
*/ def persist(newLevel: StorageLevel): RDD[T] = { // TODO: Handle changes of StorageLevel @@ -138,6 +151,20 @@ abstract class RDD[T: ClassTag]( /** Persist this RDD with the default storage level (`MEMORY_ONLY`). */ def cache(): RDD[T] = persist() + /** + * Mark the RDD as non-persistent, and remove all blocks for it from memory and disk. + * + * @param blocking Whether to block until all blocks are deleted. + * @return This RDD. + */ + def unpersist(blocking: Boolean = true): RDD[T] = { + logInfo("Removing RDD " + id + " from persistence list") + sc.env.blockManager.master.removeRdd(id, blocking) + sc.persistentRdds.remove(id) + storageLevel = StorageLevel.NONE + this + } + /** Get the RDD's current storage level, or StorageLevel.NONE if none is set. */ def getStorageLevel = storageLevel @@ -257,8 +284,8 @@ abstract class RDD[T: ClassTag]( def takeSample(withReplacement: Boolean, num: Int, seed: Int): Array[T] = { var fraction = 0.0 var total = 0 - var multiplier = 3.0 - var initialCount = count() + val multiplier = 3.0 + val initialCount = count() var maxSelected = 0 if (initialCount > Integer.MAX_VALUE - 1) { @@ -339,13 +366,36 @@ abstract class RDD[T: ClassTag]( /** * Return an RDD created by piping elements to a forked external process. */ - def pipe(command: Seq[String]): RDD[String] = new PipedRDD(this, command) + def pipe(command: String, env: Map[String, String]): RDD[String] = + new PipedRDD(this, command, env) + /** * Return an RDD created by piping elements to a forked external process. - */ - def pipe(command: Seq[String], env: Map[String, String]): RDD[String] = - new PipedRDD(this, command, env) + * The print behavior can be customized by providing two functions. + * + * @param command command to run in forked process. + * @param env environment variables to set. + * @param printPipeContext Before piping elements, this function is called as an opportunity + * to pipe context data. 
Print line function (like out.println) will be + * passed as printPipeContext's parameter. + * @param printRDDElement Use this function to customize how to pipe elements. This function + * will be called with each RDD element as the 1st parameter, and the + * print line function (like out.println()) as the 2nd parameter. + * An example of pipe the RDD data of groupBy() in a streaming way, + * instead of constructing a huge String to concat all the elements: + * def printRDDElement(record:(String, Seq[String]), f:String=>Unit) = + * for (e <- record._2){f(e)} + * @return the result RDD + */ + def pipe( + command: Seq[String], + env: Map[String, String] = Map(), + printPipeContext: (String => Unit) => Unit = null, + printRDDElement: (T, String => Unit) => Unit = null): RDD[String] = + new PipedRDD(this, command, env, + if (printPipeContext ne null) sc.clean(printPipeContext) else null, + if (printRDDElement ne null) sc.clean(printRDDElement) else null) /** * Return a new RDD by applying a function to each partition of this RDD. @@ -437,6 +487,31 @@ abstract class RDD[T: ClassTag]( */ def zip[U: ClassTag](other: RDD[U]): RDD[(T, U)] = new ZippedRDD(sc, this, other) + /** + * Zip this RDD's partitions with one (or more) RDD(s) and return a new RDD by + * applying a function to the zipped partitions. Assumes that all the RDDs have the + * *same number of partitions*, but does *not* require them to have the same number + * of elements in each partition. 
+ */ + def zipPartitions[B: ClassManifest, V: ClassManifest]( + f: (Iterator[T], Iterator[B]) => Iterator[V], + rdd2: RDD[B]): RDD[V] = + new ZippedPartitionsRDD2(sc, sc.clean(f), this, rdd2) + + def zipPartitions[B: ClassManifest, C: ClassManifest, V: ClassManifest]( + f: (Iterator[T], Iterator[B], Iterator[C]) => Iterator[V], + rdd2: RDD[B], + rdd3: RDD[C]): RDD[V] = + new ZippedPartitionsRDD3(sc, sc.clean(f), this, rdd2, rdd3) + + def zipPartitions[B: ClassManifest, C: ClassManifest, D: ClassManifest, V: ClassManifest]( + f: (Iterator[T], Iterator[B], Iterator[C], Iterator[D]) => Iterator[V], + rdd2: RDD[B], + rdd3: RDD[C], + rdd4: RDD[D]): RDD[V] = + new ZippedPartitionsRDD4(sc, sc.clean(f), this, rdd2, rdd3, rdd4) + + // Actions (launch a job to return a value to the user program) /** @@ -452,7 +527,7 @@ abstract class RDD[T: ClassTag]( */ def foreachPartition(f: Iterator[T] => Unit) { val cleanF = sc.clean(f) - sc.runJob(this, (iter: Iterator[T]) => f(iter)) + sc.runJob(this, (iter: Iterator[T]) => cleanF(iter)) } /** @@ -685,6 +760,24 @@ abstract class RDD[T: ClassTag]( } /** + * Returns the top K elements from this RDD as defined by + * the specified implicit Ordering[T]. + * @param num the number of top elements to return + * @param ord the implicit ordering for T + * @return an array of top elements + */ + def top(num: Int)(implicit ord: Ordering[T]): Array[T] = { + mapPartitions { items => + val queue = new BoundedPriorityQueue[T](num) + queue ++= items + Iterator.single(queue) + }.reduce { (queue1, queue2) => + queue1 ++= queue2 + queue1 + }.toArray + } + + /** * Save this RDD as a text file, using string representations of elements. */ def saveAsTextFile(path: String) { @@ -693,6 +786,14 @@ abstract class RDD[T: ClassTag]( } /** + * Save this RDD as a compressed text file, using string representations of elements. 
+ */ + def saveAsTextFile(path: String, codec: Class[_ <: CompressionCodec]) { + this.map(x => (NullWritable.get(), new Text(x.toString))) + .saveAsHadoopFile[TextOutputFormat[NullWritable, Text]](path, codec) + } + + /** * Save this RDD as a SequenceFile of serialized objects. */ def saveAsObjectFile(path: String) { @@ -750,7 +851,7 @@ abstract class RDD[T: ClassTag]( private var storageLevel: StorageLevel = StorageLevel.NONE /** Record user function generating this RDD. */ - private[spark] val origin = Utils.getSparkCallSite + private[spark] val origin = Utils.formatSparkCallSite private[spark] def elementClassTag: ClassTag[T] = classTag[T] diff --git a/core/src/main/scala/spark/RDDCheckpointData.scala b/core/src/main/scala/spark/RDDCheckpointData.scala index 083ba9b8fa..5e7bb029eb 100644 --- a/core/src/main/scala/spark/RDDCheckpointData.scala +++ b/core/src/main/scala/spark/RDDCheckpointData.scala @@ -3,6 +3,7 @@ package spark import scala.reflect.ClassTag import org.apache.hadoop.fs.Path +import org.apache.hadoop.conf.Configuration import rdd.{CheckpointRDD, CoalescedRDD} @@ -66,14 +67,20 @@ private[spark] class RDDCheckpointData[T: ClassTag](rdd: RDD[T]) } } + // Create the output path for the checkpoint + val path = new Path(rdd.context.checkpointDir.get, "rdd-" + rdd.id) + val fs = path.getFileSystem(new Configuration()) + if (!fs.mkdirs(path)) { + throw new SparkException("Failed to create checkpoint path " + path) + } + // Save to file, and reload it as an RDD - val path = new Path(rdd.context.checkpointDir.get, "rdd-" + rdd.id).toString - rdd.context.runJob(rdd, CheckpointRDD.writeToFile(path) _) - val newRDD = new CheckpointRDD[T](rdd.context, path) + rdd.context.runJob(rdd, CheckpointRDD.writeToFile(path.toString) _) + val newRDD = new CheckpointRDD[T](rdd.context, path.toString) // Change the dependencies and partitions of the RDD RDDCheckpointData.synchronized { - cpFile = Some(path) + cpFile = Some(path.toString) cpRDD = Some(newRDD) 
rdd.markCheckpointed(newRDD) // Update the RDD's dependencies and partitions cpState = Checkpointed diff --git a/core/src/main/scala/spark/SequenceFileRDDFunctions.scala b/core/src/main/scala/spark/SequenceFileRDDFunctions.scala index 883a0152bb..edfde37da3 100644 --- a/core/src/main/scala/spark/SequenceFileRDDFunctions.scala +++ b/core/src/main/scala/spark/SequenceFileRDDFunctions.scala @@ -19,6 +19,7 @@ import org.apache.hadoop.mapred.TextOutputFormat import org.apache.hadoop.mapred.SequenceFileOutputFormat import org.apache.hadoop.mapred.OutputCommitter import org.apache.hadoop.mapred.FileOutputCommitter +import org.apache.hadoop.io.compress.CompressionCodec import org.apache.hadoop.io.Writable import org.apache.hadoop.io.NullWritable import org.apache.hadoop.io.BytesWritable @@ -63,7 +64,7 @@ class SequenceFileRDDFunctions[K <% Writable: ClassTag, V <% Writable : ClassTag * byte arrays to BytesWritable, and Strings to Text. The `path` can be on any Hadoop-supported * file system. */ - def saveAsSequenceFile(path: String) { + def saveAsSequenceFile(path: String, codec: Option[Class[_ <: CompressionCodec]] = None) { def anyToWritable[U <% Writable](u: U): Writable = u val keyClass = getWritableClass[K] @@ -73,14 +74,18 @@ class SequenceFileRDDFunctions[K <% Writable: ClassTag, V <% Writable : ClassTag logInfo("Saving as sequence file of type (" + keyClass.getSimpleName + "," + valueClass.getSimpleName + ")" ) val format = classOf[SequenceFileOutputFormat[Writable, Writable]] + val jobConf = new JobConf(self.context.hadoopConfiguration) if (!convertKey && !convertValue) { - self.saveAsHadoopFile(path, keyClass, valueClass, format) + self.saveAsHadoopFile(path, keyClass, valueClass, format, jobConf, codec) } else if (!convertKey && convertValue) { - self.map(x => (x._1,anyToWritable(x._2))).saveAsHadoopFile(path, keyClass, valueClass, format) + self.map(x => (x._1,anyToWritable(x._2))).saveAsHadoopFile( + path, keyClass, valueClass, format, jobConf, codec) } else 
if (convertKey && !convertValue) { - self.map(x => (anyToWritable(x._1),x._2)).saveAsHadoopFile(path, keyClass, valueClass, format) + self.map(x => (anyToWritable(x._1),x._2)).saveAsHadoopFile( + path, keyClass, valueClass, format, jobConf, codec) } else if (convertKey && convertValue) { - self.map(x => (anyToWritable(x._1),anyToWritable(x._2))).saveAsHadoopFile(path, keyClass, valueClass, format) + self.map(x => (anyToWritable(x._1),anyToWritable(x._2))).saveAsHadoopFile( + path, keyClass, valueClass, format, jobConf, codec) } } } diff --git a/core/src/main/scala/spark/ShuffleFetcher.scala b/core/src/main/scala/spark/ShuffleFetcher.scala index 442e9f0269..9513a00126 100644 --- a/core/src/main/scala/spark/ShuffleFetcher.scala +++ b/core/src/main/scala/spark/ShuffleFetcher.scala @@ -1,13 +1,16 @@ package spark -import executor.TaskMetrics +import spark.executor.TaskMetrics +import spark.serializer.Serializer + private[spark] abstract class ShuffleFetcher { /** * Fetch the shuffle outputs for a given ShuffleDependency. * @return An iterator over the elements of the fetched shuffle outputs. 
*/ - def fetch[K, V](shuffleId: Int, reduceId: Int, metrics: TaskMetrics) : Iterator[(K,V)] + def fetch[K, V](shuffleId: Int, reduceId: Int, metrics: TaskMetrics, + serializer: Serializer = SparkEnv.get.serializerManager.default): Iterator[(K,V)] /** Stop the fetcher */ def stop() {} diff --git a/core/src/main/scala/spark/SizeEstimator.scala b/core/src/main/scala/spark/SizeEstimator.scala index d4e1157250..f8a4c4e489 100644 --- a/core/src/main/scala/spark/SizeEstimator.scala +++ b/core/src/main/scala/spark/SizeEstimator.scala @@ -198,7 +198,7 @@ private[spark] object SizeEstimator extends Logging { val elem = JArray.get(array, index) size += SizeEstimator.estimate(elem, state.visited) } - state.size += ((length / 100.0) * size).toLong + state.size += ((length / (ARRAY_SAMPLE_SIZE * 1.0)) * size).toLong } } } diff --git a/core/src/main/scala/spark/SparkContext.scala b/core/src/main/scala/spark/SparkContext.scala index 7272a592a5..ef6de87193 100644 --- a/core/src/main/scala/spark/SparkContext.scala +++ b/core/src/main/scala/spark/SparkContext.scala @@ -1,48 +1,56 @@ package spark import java.io._ -import java.util.concurrent.atomic.AtomicInteger import java.net.URI +import java.util.Properties +import java.util.concurrent.ConcurrentHashMap +import java.util.concurrent.atomic.AtomicInteger +import scala.collection.JavaConversions._ import scala.collection.Map import scala.collection.generic.Growable import scala.collection.mutable.HashMap import scala.collection.JavaConversions._ + import scala.reflect.{ ClassTag, classTag} -import org.apache.hadoop.fs.Path +import scala.util.DynamicVariable +import scala.collection.mutable.{ConcurrentMap, HashMap} + +import akka.actor.Actor._ + import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.mapred.InputFormat -import org.apache.hadoop.mapred.SequenceFileInputFormat -import org.apache.hadoop.io.Writable -import org.apache.hadoop.io.IntWritable -import org.apache.hadoop.io.LongWritable -import 
org.apache.hadoop.io.FloatWritable -import org.apache.hadoop.io.DoubleWritable +import org.apache.hadoop.fs.Path +import org.apache.hadoop.io.ArrayWritable import org.apache.hadoop.io.BooleanWritable import org.apache.hadoop.io.BytesWritable -import org.apache.hadoop.io.ArrayWritable +import org.apache.hadoop.io.DoubleWritable +import org.apache.hadoop.io.FloatWritable +import org.apache.hadoop.io.IntWritable +import org.apache.hadoop.io.LongWritable import org.apache.hadoop.io.NullWritable import org.apache.hadoop.io.Text +import org.apache.hadoop.io.Writable import org.apache.hadoop.mapred.FileInputFormat +import org.apache.hadoop.mapred.InputFormat import org.apache.hadoop.mapred.JobConf +import org.apache.hadoop.mapred.SequenceFileInputFormat import org.apache.hadoop.mapred.TextInputFormat import org.apache.hadoop.mapreduce.{InputFormat => NewInputFormat} -import org.apache.hadoop.mapreduce.lib.input.{FileInputFormat => NewFileInputFormat} import org.apache.hadoop.mapreduce.{Job => NewHadoopJob} +import org.apache.hadoop.mapreduce.lib.input.{FileInputFormat => NewFileInputFormat} + import org.apache.mesos.MesosNativeLibrary -import spark.deploy.LocalSparkCluster -import spark.partial.ApproximateEvaluator -import spark.partial.PartialResult +import spark.deploy.{LocalSparkCluster, SparkHadoopUtil} +import spark.partial.{ApproximateEvaluator, PartialResult} import spark.rdd.{CheckpointRDD, HadoopRDD, NewHadoopRDD, UnionRDD, ParallelCollectionRDD} -import spark.scheduler._ +import spark.scheduler.{DAGScheduler, ResultTask, ShuffleMapTask, SparkListener, SplitInfo, Stage, StageInfo, TaskScheduler} +import spark.scheduler.cluster.{StandaloneSchedulerBackend, SparkDeploySchedulerBackend, ClusterScheduler} import spark.scheduler.local.LocalScheduler -import spark.scheduler.cluster.{SparkDeploySchedulerBackend, SchedulerBackend, ClusterScheduler} import spark.scheduler.mesos.{CoarseMesosSchedulerBackend, MesosSchedulerBackend} -import spark.storage.BlockManagerUI 
+import spark.storage.{BlockManagerUI, StorageStatus, StorageUtils, RDDInfo} import spark.util.{MetadataCleaner, TimeStampedHashMap} -import spark.storage.{StorageStatus, StorageUtils, RDDInfo} /** * Main entry point for Spark functionality. A SparkContext represents the connection to a Spark @@ -60,7 +68,10 @@ class SparkContext( val appName: String, val sparkHome: String = null, val jars: Seq[String] = Nil, - val environment: Map[String, String] = Map()) + val environment: Map[String, String] = Map(), + // This is used only by yarn for now, but should be relevant to other cluster types (mesos, etc) too. + // This is typically generated from InputFormatInfo.computePreferredLocations .. host, set of data-local splits on host + val preferredNodeLocationData: scala.collection.Map[String, scala.collection.Set[SplitInfo]] = scala.collection.immutable.Map()) extends Logging { // Ensure logging is initialized before we spawn any threads @@ -68,7 +79,7 @@ class SparkContext( // Set Spark driver host and port system properties if (System.getProperty("spark.driver.host") == null) { - System.setProperty("spark.driver.host", Utils.localIpAddress) + System.setProperty("spark.driver.host", Utils.localHostName()) } if (System.getProperty("spark.driver.port") == null) { System.setProperty("spark.driver.port", "0") @@ -95,24 +106,29 @@ class SparkContext( private[spark] val addedJars = HashMap[String, Long]() // Keeps track of all persisted RDDs - private[spark] val persistentRdds = new TimeStampedHashMap[Int, RDD[_]]() + private[spark] val persistentRdds = new TimeStampedHashMap[Int, RDD[_]] private[spark] val metadataCleaner = new MetadataCleaner("SparkContext", this.cleanup) // Add each JAR given through the constructor - jars.foreach { addJar(_) } + if (jars != null) { + jars.foreach { addJar(_) } + } // Environment variables to pass to our executors private[spark] val executorEnvs = HashMap[String, String]() // Note: SPARK_MEM is included for Mesos, but overwritten for 
standalone mode in ExecutorRunner - for (key <- Seq("SPARK_MEM", "SPARK_CLASSPATH", "SPARK_LIBRARY_PATH", "SPARK_JAVA_OPTS", - "SPARK_TESTING")) { + for (key <- Seq("SPARK_CLASSPATH", "SPARK_LIBRARY_PATH", "SPARK_JAVA_OPTS", "SPARK_TESTING")) { val value = System.getenv(key) if (value != null) { executorEnvs(key) = value } } - executorEnvs ++= environment + // Since memory can be set with a system property too, use that + executorEnvs("SPARK_MEM") = SparkContext.executorMemoryRequested + "m" + if (environment != null) { + executorEnvs ++= environment + } // Create and start the scheduler private var taskScheduler: TaskScheduler = { @@ -144,14 +160,12 @@ class SparkContext( scheduler case LOCAL_CLUSTER_REGEX(numSlaves, coresPerSlave, memoryPerSlave) => - // Check to make sure SPARK_MEM <= memoryPerSlave. Otherwise Spark will just hang. + // Check to make sure memory requested <= memoryPerSlave. Otherwise Spark will just hang. val memoryPerSlaveInt = memoryPerSlave.toInt - val sparkMemEnv = System.getenv("SPARK_MEM") - val sparkMemEnvInt = if (sparkMemEnv != null) Utils.memoryStringToMb(sparkMemEnv) else 512 - if (sparkMemEnvInt > memoryPerSlaveInt) { + if (SparkContext.executorMemoryRequested > memoryPerSlaveInt) { throw new SparkException( - "Slave memory (%d MB) cannot be smaller than SPARK_MEM (%d MB)".format( - memoryPerSlaveInt, sparkMemEnvInt)) + "Asked to launch cluster with %d MB RAM / worker but requested %d MB/worker".format( + memoryPerSlaveInt, SparkContext.executorMemoryRequested)) } val scheduler = new ClusterScheduler(this) @@ -165,6 +179,22 @@ class SparkContext( } scheduler + case "yarn-standalone" => + val scheduler = try { + val clazz = Class.forName("spark.scheduler.cluster.YarnClusterScheduler") + val cons = clazz.getConstructor(classOf[SparkContext]) + cons.newInstance(this).asInstanceOf[ClusterScheduler] + } catch { + // TODO: Enumerate the exact reasons why it can fail + // But irrespective of it, it means we cannot proceed ! 
+ case th: Throwable => { + throw new SparkException("YARN mode not available ?", th) + } + } + val backend = new StandaloneSchedulerBackend(scheduler, this.env.actorSystem) + scheduler.initialize(backend) + scheduler + case _ => if (MESOS_REGEX.findFirstIn(master).isEmpty) { logWarning("Master %s does not match expected format, parsing as Mesos URL".format(master)) @@ -184,12 +214,12 @@ class SparkContext( } taskScheduler.start() - private var dagScheduler = new DAGScheduler(taskScheduler) + @volatile private var dagScheduler = new DAGScheduler(taskScheduler) dagScheduler.start() /** A default Hadoop Configuration for the Hadoop code (e.g. file systems) that we reuse. */ val hadoopConfiguration = { - val conf = new Configuration() + val conf = SparkHadoopUtil.newConfiguration() // Explicitly check for S3 environment variables if (System.getenv("AWS_ACCESS_KEY_ID") != null && System.getenv("AWS_SECRET_ACCESS_KEY") != null) { conf.set("fs.s3.awsAccessKeyId", System.getenv("AWS_ACCESS_KEY_ID")) @@ -208,6 +238,22 @@ class SparkContext( private[spark] var checkpointDir: Option[String] = None + // Thread Local variable that can be used by users to pass information down the stack + private val localProperties = new DynamicVariable[Properties](null) + + def initLocalProperties() { + localProperties.value = new Properties() + } + + def addLocalProperties(key: String, value: String) { + if(localProperties.value == null) { + localProperties.value = new Properties() + } + localProperties.value.setProperty(key,value) + } + // Post init + taskScheduler.postStartHook() + // Methods for creating RDDs /** Distribute a local Scala collection to form an RDD. 
*/ @@ -472,7 +518,7 @@ class SparkContext( */ def getExecutorMemoryStatus: Map[String, (Long, Long)] = { env.blockManager.master.getMemoryStatus.map { case(blockManagerId, mem) => - (blockManagerId.ip + ":" + blockManagerId.port, mem) + (blockManagerId.host + ":" + blockManagerId.port, mem) } } @@ -480,7 +526,7 @@ class SparkContext( * Return information about what RDDs are cached, if they are in mem or on disk, how much space * they take, etc. */ - def getRDDStorageInfo : Array[RDDInfo] = { + def getRDDStorageInfo: Array[RDDInfo] = { StorageUtils.rddInfoFromStorageStatus(getExecutorStorageStatus, this) } @@ -491,7 +537,7 @@ class SparkContext( /** * Return information about blocks stored in all of the slaves */ - def getExecutorStorageStatus : Array[StorageStatus] = { + def getExecutorStorageStatus: Array[StorageStatus] = { env.blockManager.master.getStorageStatus } @@ -509,13 +555,18 @@ class SparkContext( * filesystems), or an HTTP, HTTPS or FTP URI. */ def addJar(path: String) { - val uri = new URI(path) - val key = uri.getScheme match { - case null | "file" => env.httpFileServer.addJar(new File(uri.getPath)) - case _ => path + if (null == path) { + logWarning("null specified as parameter to addJar", + new SparkException("null specified as parameter to addJar")) + } else { + val uri = new URI(path) + val key = uri.getScheme match { + case null | "file" => env.httpFileServer.addJar(new File(uri.getPath)) + case _ => path + } + addedJars(key) = System.currentTimeMillis + logInfo("Added JAR " + path + " at " + key + " with timestamp " + addedJars(key)) } - addedJars(key) = System.currentTimeMillis - logInfo("Added JAR " + path + " at " + key + " with timestamp " + addedJars(key)) } /** @@ -528,10 +579,13 @@ class SparkContext( /** Shut down the SparkContext. */ def stop() { - if (dagScheduler != null) { + // Do this only if not stopped already - best case effort. + // prevent NPE if stopped more than once. 
+ val dagSchedulerCopy = dagScheduler + dagScheduler = null + if (dagSchedulerCopy != null) { metadataCleaner.cancel() - dagScheduler.stop() - dagScheduler = null + dagSchedulerCopy.stop() taskScheduler = null // TODO: Cache.stop()? env.stop() @@ -547,6 +601,7 @@ class SparkContext( } } + /** * Get Spark's home location from either a value set through the constructor, * or the spark.home Java property, or the SPARK_HOME environment variable @@ -576,10 +631,10 @@ class SparkContext( partitions: Seq[Int], allowLocal: Boolean, resultHandler: (Int, U) => Unit) { - val callSite = Utils.getSparkCallSite + val callSite = Utils.formatSparkCallSite logInfo("Starting job: " + callSite) val start = System.nanoTime - val result = dagScheduler.runJob(rdd, func, partitions, callSite, allowLocal, resultHandler) + val result = dagScheduler.runJob(rdd, func, partitions, callSite, allowLocal, resultHandler, localProperties.value) logInfo("Job finished: " + callSite + ", took " + (System.nanoTime - start) / 1e9 + " s") rdd.doCheckpoint() result @@ -658,12 +713,11 @@ class SparkContext( rdd: RDD[T], func: (TaskContext, Iterator[T]) => U, evaluator: ApproximateEvaluator[U, R], - timeout: Long - ): PartialResult[R] = { - val callSite = Utils.getSparkCallSite + timeout: Long): PartialResult[R] = { + val callSite = Utils.formatSparkCallSite logInfo("Starting job: " + callSite) val start = System.nanoTime - val result = dagScheduler.runApproximateJob(rdd, func, evaluator, callSite, timeout) + val result = dagScheduler.runApproximateJob(rdd, func, evaluator, callSite, timeout, localProperties.value) logInfo("Job finished: " + callSite + ", took " + (System.nanoTime - start) / 1e9 + " s") result } @@ -686,7 +740,7 @@ class SparkContext( */ def setCheckpointDir(dir: String, useExisting: Boolean = false) { val path = new Path(dir) - val fs = path.getFileSystem(new Configuration()) + val fs = path.getFileSystem(SparkHadoopUtil.newConfiguration()) if (!useExisting) { if (fs.exists(path)) { throw 
new Exception("Checkpoint directory '" + path + "' already exists.") @@ -829,6 +883,15 @@ object SparkContext { /** Find the JAR that contains the class of a particular object */ def jarOfObject(obj: AnyRef): Seq[String] = jarOfClass(obj.getClass) + + /** Get the amount of memory per executor requested through system properties or SPARK_MEM */ + private[spark] val executorMemoryRequested = { + // TODO: Might need to add some extra memory for the non-heap parts of the JVM + Option(System.getProperty("spark.executor.memory")) + .orElse(Option(System.getenv("SPARK_MEM"))) + .map(Utils.memoryStringToMb) + .getOrElse(512) + } } diff --git a/core/src/main/scala/spark/SparkEnv.scala b/core/src/main/scala/spark/SparkEnv.scala index 144ddea35f..89d52064e1 100644 --- a/core/src/main/scala/spark/SparkEnv.scala +++ b/core/src/main/scala/spark/SparkEnv.scala @@ -1,14 +1,19 @@ package spark +import collection.mutable +import serializer.Serializer + import akka.actor.{Actor, ActorRef, Props, ActorSystemImpl, ActorSystem} import akka.remote.RemoteActorRefProvider -import serializer.Serializer import spark.broadcast.BroadcastManager import spark.storage.BlockManager import spark.storage.BlockManagerMaster import spark.network.ConnectionManager +import spark.serializer.{Serializer, SerializerManager} import spark.util.AkkaUtils +import spark.api.python.PythonWorkerFactory + /** * Holds all the runtime environment objects for a running Spark instance (either master or worker), @@ -20,6 +25,7 @@ import spark.util.AkkaUtils class SparkEnv ( val executorId: String, val actorSystem: ActorSystem, + val serializerManager: SerializerManager, val serializer: Serializer, val closureSerializer: Serializer, val cacheManager: CacheManager, @@ -29,10 +35,16 @@ class SparkEnv ( val blockManager: BlockManager, val connectionManager: ConnectionManager, val httpFileServer: HttpFileServer, - val sparkFilesDir: String - ) { + val sparkFilesDir: String, + // To be set only as part of initialization of 
SparkContext. + // (executorId, defaultHostPort) => executorHostPort + // If executorId is NOT found, return defaultHostPort + var executorIdToHostPort: Option[(String, String) => String]) { + + private val pythonWorkers = mutable.HashMap[(String, Map[String, String]), PythonWorkerFactory]() def stop() { + pythonWorkers.foreach { case(key, worker) => worker.stop() } httpFileServer.stop() mapOutputTracker.stop() shuffleFetcher.stop() @@ -45,6 +57,23 @@ class SparkEnv ( // UPDATE: In Akka 2.1.x, this hangs if there are remote actors, so we can't call it. //actorSystem.awaitTermination() } + + def createPythonWorker(pythonExec: String, envVars: Map[String, String]): java.net.Socket = { + synchronized { + val key = (pythonExec, envVars) + pythonWorkers.getOrElseUpdate(key, new PythonWorkerFactory(pythonExec, envVars)).create() + } + } + + def resolveExecutorIdToHostPort(executorId: String, defaultHostPort: String): String = { + val env = SparkEnv.get + if (env.executorIdToHostPort.isEmpty) { + // default to using host, not host port. Relevant to non cluster modes. + return defaultHostPort + } + + env.executorIdToHostPort.get(executorId, defaultHostPort) + } } object SparkEnv extends Logging { @@ -73,6 +102,16 @@ object SparkEnv extends Logging { System.setProperty("spark.driver.port", boundPort.toString) } + // set only if unset until now. 
+ if (System.getProperty("spark.hostPort", null) == null) { + if (!isDriver){ + // unexpected + Utils.logErrorWithStack("Unexpected NOT to have spark.hostPort set") + } + Utils.checkHost(hostname) + System.setProperty("spark.hostPort", hostname + ":" + boundPort) + } + val classLoader = Thread.currentThread.getContextClassLoader // Create an instance of the class named by the given Java system property, or by @@ -82,16 +121,23 @@ object SparkEnv extends Logging { Class.forName(name, true, classLoader).newInstance().asInstanceOf[T] } - val serializer = instantiateClass[Serializer]("spark.serializer", "spark.JavaSerializer") - + val serializerManager = new SerializerManager + + val serializer = serializerManager.setDefault( + System.getProperty("spark.serializer", "spark.JavaSerializer")) + + val closureSerializer = serializerManager.get( + System.getProperty("spark.closure.serializer", "spark.JavaSerializer")) + def registerOrLookup(name: String, newActor: => Actor): ActorRef = { if (isDriver) { logInfo("Registering " + name) actorSystem.actorOf(Props(newActor), name = name) } else { - val driverIp: String = System.getProperty("spark.driver.host", "localhost") + val driverHost: String = System.getProperty("spark.driver.host", "localhost") val driverPort: Int = System.getProperty("spark.driver.port", "7077").toInt - val url = "akka://spark@%s:%s/user/%s".format(driverIp, driverPort, name) + Utils.checkHost(driverHost, "Expected hostname") + val url = "akka://spark@%s:%s/user/%s".format(driverHost, driverPort, name) logInfo("Connecting to " + name + ": " + url) actorSystem.actorFor(url) } @@ -106,9 +152,6 @@ object SparkEnv extends Logging { val broadcastManager = new BroadcastManager(isDriver) - val closureSerializer = instantiateClass[Serializer]( - "spark.closure.serializer", "spark.JavaSerializer") - val cacheManager = new CacheManager(blockManager) // Have to assign trackerActor after initialization as MapOutputTrackerActor @@ -143,6 +186,7 @@ object SparkEnv 
extends Logging { new SparkEnv( executorId, actorSystem, + serializerManager, serializer, closureSerializer, cacheManager, @@ -152,7 +196,7 @@ object SparkEnv extends Logging { blockManager, connectionManager, httpFileServer, - sparkFilesDir) + sparkFilesDir, + None) } - } diff --git a/core/src/main/scala/spark/TaskEndReason.scala b/core/src/main/scala/spark/TaskEndReason.scala index 420c54bc9a..8140cba084 100644 --- a/core/src/main/scala/spark/TaskEndReason.scala +++ b/core/src/main/scala/spark/TaskEndReason.scala @@ -14,9 +14,19 @@ private[spark] case object Success extends TaskEndReason private[spark] case object Resubmitted extends TaskEndReason // Task was finished earlier but we've now lost it -private[spark] -case class FetchFailed(bmAddress: BlockManagerId, shuffleId: Int, mapId: Int, reduceId: Int) extends TaskEndReason +private[spark] case class FetchFailed( + bmAddress: BlockManagerId, + shuffleId: Int, + mapId: Int, + reduceId: Int) + extends TaskEndReason -private[spark] case class ExceptionFailure(exception: Throwable) extends TaskEndReason +private[spark] case class ExceptionFailure( + className: String, + description: String, + stackTrace: Array[StackTraceElement]) + extends TaskEndReason private[spark] case class OtherFailure(message: String) extends TaskEndReason + +private[spark] case class TaskResultTooBigFailure() extends TaskEndReason diff --git a/core/src/main/scala/spark/Utils.scala b/core/src/main/scala/spark/Utils.scala index cdccb8b336..e02507f83e 100644 --- a/core/src/main/scala/spark/Utils.scala +++ b/core/src/main/scala/spark/Utils.scala @@ -1,14 +1,16 @@ package spark import java.io._ -import java.net._ +import java.net.{InetAddress, URL, URI, NetworkInterface, Inet4Address, ServerSocket} import java.util.{Locale, Random, UUID} -import java.util.concurrent.{Executors, ThreadFactory, ThreadPoolExecutor} + +import java.util.concurrent.{ConcurrentHashMap, Executors, ThreadFactory, ThreadPoolExecutor} +import java.util.regex.Pattern 
import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{Path, FileSystem, FileUtil} -import scala.collection.mutable.ArrayBuffer +import scala.collection.mutable.{ArrayBuffer, HashMap} import scala.collection.JavaConversions._ import scala.io.Source import scala.reflect.ClassTag @@ -18,11 +20,14 @@ import com.google.common.io.Files import com.google.common.util.concurrent.ThreadFactoryBuilder import spark.serializer.SerializerInstance +import spark.deploy.SparkHadoopUtil + /** * Various utility methods used by Spark. */ private object Utils extends Logging { + /** Serialize an object using Java serialization */ def serialize[T](o: T): Array[Byte] = { val bos = new ByteArrayOutputStream() @@ -73,6 +78,40 @@ private object Utils extends Logging { return buf } + private val shutdownDeletePaths = new collection.mutable.HashSet[String]() + + // Register the path to be deleted via shutdown hook + def registerShutdownDeleteDir(file: File) { + val absolutePath = file.getAbsolutePath() + shutdownDeletePaths.synchronized { + shutdownDeletePaths += absolutePath + } + } + + // Is the path already registered to be deleted via a shutdown hook ? + def hasShutdownDeleteDir(file: File): Boolean = { + val absolutePath = file.getAbsolutePath() + shutdownDeletePaths.synchronized { + shutdownDeletePaths.contains(absolutePath) + } + } + + // Note: if file is child of some registered path, while not equal to it, then return true; + // else false. This is to ensure that two shutdown hooks do not try to delete each others + // paths - resulting in IOException and incomplete cleanup. 
+ def hasRootAsShutdownDeleteDir(file: File): Boolean = { + val absolutePath = file.getAbsolutePath() + val retval = shutdownDeletePaths.synchronized { + shutdownDeletePaths.find { path => + !absolutePath.equals(path) && absolutePath.startsWith(path) + }.isDefined + } + if (retval) { + logInfo("path = " + file + ", already present as root for deletion.") + } + retval + } + /** Create a temporary directory inside the given parent directory */ def createTempDir(root: String = System.getProperty("java.io.tmpdir")): File = { var attempts = 0 @@ -81,8 +120,8 @@ private object Utils extends Logging { while (dir == null) { attempts += 1 if (attempts > maxAttempts) { - throw new IOException("Failed to create a temp directory after " + maxAttempts + - " attempts!") + throw new IOException("Failed to create a temp directory (under " + root + ") after " + + maxAttempts + " attempts!") } try { dir = new File(root, "spark-" + UUID.randomUUID.toString) @@ -91,13 +130,17 @@ private object Utils extends Logging { } } catch { case e: IOException => ; } } + + registerShutdownDeleteDir(dir) + // Add a shutdown hook to delete the temp dir when the JVM exits Runtime.getRuntime.addShutdownHook(new Thread("delete Spark temp dir " + dir) { override def run() { - Utils.deleteRecursively(dir) + // Attempt to delete if some patch which is parent of this is not already registered. + if (! 
hasRootAsShutdownDeleteDir(dir)) Utils.deleteRecursively(dir) } }) - return dir + dir } /** Copy all data from an InputStream to an OutputStream */ @@ -140,40 +183,35 @@ private object Utils extends Logging { Utils.copyStream(in, out, true) if (targetFile.exists && !Files.equal(tempFile, targetFile)) { tempFile.delete() - throw new SparkException("File " + targetFile + " exists and does not match contents of" + - " " + url) + throw new SparkException( + "File " + targetFile + " exists and does not match contents of" + " " + url) } else { Files.move(tempFile, targetFile) } case "file" | null => - val sourceFile = if (uri.isAbsolute) { - new File(uri) - } else { - new File(url) - } - if (targetFile.exists && !Files.equal(sourceFile, targetFile)) { - throw new SparkException("File " + targetFile + " exists and does not match contents of" + - " " + url) - } else { - // Remove the file if it already exists - targetFile.delete() - // Symlink the file locally. - if (uri.isAbsolute) { - // url is absolute, i.e. it starts with "file:///". Extract the source - // file's absolute path from the url. - val sourceFile = new File(uri) - logInfo("Symlinking " + sourceFile.getAbsolutePath + " to " + targetFile.getAbsolutePath) - FileUtil.symLink(sourceFile.getAbsolutePath, targetFile.getAbsolutePath) + // In the case of a local file, copy the local file to the target directory. + // Note the difference between uri vs url. + val sourceFile = if (uri.isAbsolute) new File(uri) else new File(url) + if (targetFile.exists) { + // If the target file already exists, warn the user if + if (!Files.equal(sourceFile, targetFile)) { + throw new SparkException( + "File " + targetFile + " exists and does not match contents of" + " " + url) } else { - // url is not absolute, i.e. itself is the path to the source file. - logInfo("Symlinking " + url + " to " + targetFile.getAbsolutePath) - FileUtil.symLink(url, targetFile.getAbsolutePath) + // Do nothing if the file contents are the same, i.e. 
this file has been copied + // previously. + logInfo(sourceFile.getAbsolutePath + " has been previously copied to " + + targetFile.getAbsolutePath) } + } else { + // The file does not exist in the target directory. Copy it there. + logInfo("Copying " + sourceFile.getAbsolutePath + " to " + targetFile.getAbsolutePath) + Files.copy(sourceFile, targetFile) } case _ => // Use the Hadoop filesystem library, which supports file://, hdfs://, s3://, and others val uri = new URI(url) - val conf = new Configuration() + val conf = SparkHadoopUtil.newConfiguration() val fs = FileSystem.get(uri, conf) val in = fs.open(new Path(uri)) val out = new FileOutputStream(tempFile) @@ -232,8 +270,10 @@ private object Utils extends Logging { /** * Get the local host's IP address in dotted-quad format (e.g. 1.2.3.4). + * Note, this is typically not used from within core spark. */ lazy val localIpAddress: String = findLocalIpAddress() + lazy val localIpAddressHostname: String = getAddressHostName(localIpAddress) private def findLocalIpAddress(): String = { val defaultIpOverride = System.getenv("SPARK_LOCAL_IP") @@ -271,6 +311,8 @@ private object Utils extends Logging { * hostname it reports to the master. */ def setCustomHostname(hostname: String) { + // DEBUG code + Utils.checkHost(hostname) customHostname = Some(hostname) } @@ -278,7 +320,91 @@ private object Utils extends Logging { * Get the local machine's hostname. 
*/ def localHostName(): String = { - customHostname.getOrElse(InetAddress.getLocalHost.getHostName) + customHostname.getOrElse(localIpAddressHostname) + } + + def getAddressHostName(address: String): String = { + InetAddress.getByName(address).getHostName + } + + def localHostPort(): String = { + val retval = System.getProperty("spark.hostPort", null) + if (retval == null) { + logErrorWithStack("spark.hostPort not set but invoking localHostPort") + return localHostName() + } + + retval + } + +/* + // Used by DEBUG code : remove when all testing done + private val ipPattern = Pattern.compile("^[0-9]+(\\.[0-9]+)*$") + def checkHost(host: String, message: String = "") { + // Currently catches only ipv4 pattern, this is just a debugging tool - not rigourous ! + // if (host.matches("^[0-9]+(\\.[0-9]+)*$")) { + if (ipPattern.matcher(host).matches()) { + Utils.logErrorWithStack("Unexpected to have host " + host + " which matches IP pattern. Message " + message) + } + if (Utils.parseHostPort(host)._2 != 0){ + Utils.logErrorWithStack("Unexpected to have host " + host + " which has port in it. Message " + message) + } + } + + // Used by DEBUG code : remove when all testing done + def checkHostPort(hostPort: String, message: String = "") { + val (host, port) = Utils.parseHostPort(hostPort) + checkHost(host) + if (port <= 0){ + Utils.logErrorWithStack("Unexpected to have port " + port + " which is not valid in " + hostPort + ". Message " + message) + } + } + + // Used by DEBUG code : remove when all testing done + def logErrorWithStack(msg: String) { + try { throw new Exception } catch { case ex: Exception => { logError(msg, ex) } } + // temp code for debug + System.exit(-1) + } +*/ + + // Once testing is complete in various modes, replace with this ? 
+ def checkHost(host: String, message: String = "") {} + def checkHostPort(hostPort: String, message: String = "") {} + + // Used by DEBUG code : remove when all testing done + def logErrorWithStack(msg: String) { + try { throw new Exception } catch { case ex: Exception => { logError(msg, ex) } } + } + + def getUserNameFromEnvironment(): String = { + SparkHadoopUtil.getUserNameFromEnvironment + } + + // Typically, this will be of order of number of nodes in cluster + // If not, we should change it to LRUCache or something. + private val hostPortParseResults = new ConcurrentHashMap[String, (String, Int)]() + + def parseHostPort(hostPort: String): (String, Int) = { + { + // Check cache first. + var cached = hostPortParseResults.get(hostPort) + if (cached != null) return cached + } + + val indx: Int = hostPort.lastIndexOf(':') + // This is potentially broken - when dealing with ipv6 addresses for example, sigh ... + // but then hadoop does not support ipv6 right now. + // For now, we assume that if port exists, then it is valid - not check if it is an int > 0 + if (-1 == indx) { + val retval = (hostPort, 0) + hostPortParseResults.put(hostPort, retval) + return retval + } + + val retval = (hostPort.substring(0, indx).trim(), hostPort.substring(indx + 1).trim().toInt) + hostPortParseResults.putIfAbsent(hostPort, retval) + hostPortParseResults.get(hostPort) } private[spark] val daemonThreadFactory: ThreadFactory = @@ -400,13 +526,45 @@ private object Utils extends Logging { execute(command, new File(".")) } + /** + * Execute a command and get its output, throwing an exception if it yields a code other than 0. 
+ */ + def executeAndGetOutput(command: Seq[String], workingDir: File = new File(".")): String = { + val process = new ProcessBuilder(command: _*) + .directory(workingDir) + .start() + new Thread("read stderr for " + command(0)) { + override def run() { + for (line <- Source.fromInputStream(process.getErrorStream).getLines) { + System.err.println(line) + } + } + }.start() + val output = new StringBuffer + val stdoutThread = new Thread("read stdout for " + command(0)) { + override def run() { + for (line <- Source.fromInputStream(process.getInputStream).getLines) { + output.append(line) + } + } + } + stdoutThread.start() + val exitCode = process.waitFor() + stdoutThread.join() // Wait for it to finish reading output + if (exitCode != 0) { + throw new SparkException("Process " + command + " exited with code " + exitCode) + } + output.toString + } + private[spark] class CallSiteInfo(val lastSparkMethod: String, val firstUserFile: String, + val firstUserLine: Int, val firstUserClass: String) /** * When called inside a class in the spark package, returns the name of the user code class * (outside the spark package) that called into Spark, as well as which Spark method they called. * This is used, for example, to tell users where in their code each RDD got created. 
*/ - def getSparkCallSite: String = { + def getCallSiteInfo: CallSiteInfo = { val trace = Thread.currentThread.getStackTrace().filter( el => (!el.getMethodName.contains("getStackTrace"))) @@ -418,6 +576,7 @@ private object Utils extends Logging { var firstUserFile = "<unknown>" var firstUserLine = 0 var finished = false + var firstUserClass = "<unknown>" for (el <- trace) { if (!finished) { @@ -432,13 +591,19 @@ private object Utils extends Logging { else { firstUserLine = el.getLineNumber firstUserFile = el.getFileName + firstUserClass = el.getClassName finished = true } } } - "%s at %s:%s".format(lastSparkMethod, firstUserFile, firstUserLine) + new CallSiteInfo(lastSparkMethod, firstUserFile, firstUserLine, firstUserClass) } + def formatSparkCallSite = { + val callSiteInfo = getCallSiteInfo + "%s at %s:%s".format(callSiteInfo.lastSparkMethod, callSiteInfo.firstUserFile, + callSiteInfo.firstUserLine) + } /** * Try to find a free port to bind to on the local host. This should ideally never be needed, * except that, unfortunately, some of the networking libraries we currently rely on (e.g. Spray) @@ -480,4 +645,67 @@ private object Utils extends Logging { } return false } + + def isSpace(c: Char): Boolean = { + " \t\r\n".indexOf(c) != -1 + } + + /** + * Split a string of potentially quoted arguments from the command line the way that a shell + * would do it to determine arguments to a command. For example, if the string is 'a "b c" d', + * then it would be parsed as three arguments: 'a', 'b c' and 'd'. 
+ */ + def splitCommandString(s: String): Seq[String] = { + val buf = new ArrayBuffer[String] + var inWord = false + var inSingleQuote = false + var inDoubleQuote = false + var curWord = new StringBuilder + def endWord() { + buf += curWord.toString + curWord.clear() + } + var i = 0 + while (i < s.length) { + var nextChar = s.charAt(i) + if (inDoubleQuote) { + if (nextChar == '"') { + inDoubleQuote = false + } else if (nextChar == '\\') { + if (i < s.length - 1) { + // Append the next character directly, because only " and \ may be escaped in + // double quotes after the shell's own expansion + curWord.append(s.charAt(i + 1)) + i += 1 + } + } else { + curWord.append(nextChar) + } + } else if (inSingleQuote) { + if (nextChar == '\'') { + inSingleQuote = false + } else { + curWord.append(nextChar) + } + // Backslashes are not treated specially in single quotes + } else if (nextChar == '"') { + inWord = true + inDoubleQuote = true + } else if (nextChar == '\'') { + inWord = true + inSingleQuote = true + } else if (!isSpace(nextChar)) { + curWord.append(nextChar) + inWord = true + } else if (inWord && isSpace(nextChar)) { + endWord() + inWord = false + } + i += 1 + } + if (inWord || inDoubleQuote || inSingleQuote) { + endWord() + } + return buf + } } diff --git a/core/src/main/scala/spark/api/java/JavaPairRDD.scala b/core/src/main/scala/spark/api/java/JavaPairRDD.scala index 89c6d05383..0fa8162f3c 100644 --- a/core/src/main/scala/spark/api/java/JavaPairRDD.scala +++ b/core/src/main/scala/spark/api/java/JavaPairRDD.scala @@ -7,6 +7,7 @@ import scala.Tuple2 import scala.collection.JavaConversions._ import scala.reflect.ClassTag +import org.apache.hadoop.io.compress.CompressionCodec import org.apache.hadoop.mapred.JobConf import org.apache.hadoop.mapred.OutputFormat import org.apache.hadoop.mapreduce.{OutputFormat => NewOutputFormat} @@ -460,6 +461,16 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])(implicit val kClassTag: ClassTag[K rdd.saveAsHadoopFile(path, keyClass, 
valueClass, outputFormatClass) } + /** Output the RDD to any Hadoop-supported file system, compressing with the supplied codec. */ + def saveAsHadoopFile[F <: OutputFormat[_, _]]( + path: String, + keyClass: Class[_], + valueClass: Class[_], + outputFormatClass: Class[F], + codec: Class[_ <: CompressionCodec]) { + rdd.saveAsHadoopFile(path, keyClass, valueClass, outputFormatClass, codec) + } + /** Output the RDD to any Hadoop-supported file system. */ def saveAsNewAPIHadoopFile[F <: NewOutputFormat[_, _]]( path: String, diff --git a/core/src/main/scala/spark/api/java/JavaRDD.scala b/core/src/main/scala/spark/api/java/JavaRDD.scala index 032506383c..6f44e018e9 100644 --- a/core/src/main/scala/spark/api/java/JavaRDD.scala +++ b/core/src/main/scala/spark/api/java/JavaRDD.scala @@ -17,10 +17,16 @@ JavaRDDLike[T, JavaRDD[T]] { /** * Set this RDD's storage level to persist its values across operations after the first time - * it is computed. Can only be called once on each RDD. + * it is computed. This can only be used to assign a new storage level if the RDD does not + * have a storage level set yet.. */ def persist(newLevel: StorageLevel): JavaRDD[T] = wrapRDD(rdd.persist(newLevel)) + /** + * Mark the RDD as non-persistent, and remove all blocks for it from memory and disk. 
+ */ + def unpersist(): JavaRDD[T] = wrapRDD(rdd.unpersist()) + // Transformations (return a new RDD) /** @@ -81,7 +87,6 @@ JavaRDDLike[T, JavaRDD[T]] { */ def subtract(other: JavaRDD[T], p: Partitioner): JavaRDD[T] = wrapRDD(rdd.subtract(other, p)) - } object JavaRDD { diff --git a/core/src/main/scala/spark/api/java/JavaRDDLike.scala b/core/src/main/scala/spark/api/java/JavaRDDLike.scala index a6555081b3..3fe2011f4c 100644 --- a/core/src/main/scala/spark/api/java/JavaRDDLike.scala +++ b/core/src/main/scala/spark/api/java/JavaRDDLike.scala @@ -1,10 +1,11 @@ package spark.api.java -import java.util.{List => JList} +import java.util.{List => JList, Comparator} import scala.Tuple2 import scala.collection.JavaConversions._ import scala.reflect.ClassTag +import org.apache.hadoop.io.compress.CompressionCodec import spark.{SparkContext, Partition, RDD, TaskContext} import spark.api.java.JavaPairRDD._ import spark.api.java.function.{Function2 => JFunction2, Function => JFunction, _} @@ -183,6 +184,21 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { JavaPairRDD.fromRDD(rdd.zip(other.rdd)(other.classTag))(classTag, other.classTag) } + /** + * Zip this RDD's partitions with one (or more) RDD(s) and return a new RDD by + * applying a function to the zipped partitions. Assumes that all the RDDs have the + * *same number of partitions*, but does *not* require them to have the same number + * of elements in each partition. 
+ */ + def zipPartitions[U, V]( + f: FlatMapFunction2[java.util.Iterator[T], java.util.Iterator[U], V], + other: JavaRDDLike[U, _]): JavaRDD[V] = { + def fn = (x: Iterator[T], y: Iterator[U]) => asScalaIterator( + f.apply(asJavaIterator(x), asJavaIterator(y)).iterator()) + JavaRDD.fromRDD( + rdd.zipPartitions(fn, other.rdd)(other.classTag, f.elementType()))(f.elementType()) + } + // Actions (launch a job to return a value to the user program) /** @@ -296,6 +312,13 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { */ def saveAsTextFile(path: String) = rdd.saveAsTextFile(path) + + /** + * Save this RDD as a compressed text file, using string representations of elements. + */ + def saveAsTextFile(path: String, codec: Class[_ <: CompressionCodec]) = + rdd.saveAsTextFile(path, codec) + /** * Save this RDD as a SequenceFile of serialized objects. */ @@ -337,4 +360,29 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { def toDebugString(): String = { rdd.toDebugString } + + /** + * Returns the top K elements from this RDD as defined by + * the specified Comparator[T]. + * @param num the number of top elements to return + * @param comp the comparator that defines the order + * @return an array of top elements + */ + def top(num: Int, comp: Comparator[T]): JList[T] = { + import scala.collection.JavaConversions._ + val topElems = rdd.top(num)(Ordering.comparatorToOrdering(comp)) + val arr: java.util.Collection[T] = topElems.toSeq + new java.util.ArrayList(arr) + } + + /** + * Returns the top K elements from this RDD using the + * natural ordering for T. 
+ * @param num the number of top elements to return + * @return an array of top elements + */ + def top(num: Int): JList[T] = { + val comp = com.google.common.collect.Ordering.natural().asInstanceOf[Comparator[T]] + top(num, comp) + } } diff --git a/core/src/main/scala/spark/api/java/function/FlatMapFunction2.scala b/core/src/main/scala/spark/api/java/function/FlatMapFunction2.scala new file mode 100644 index 0000000000..6044043add --- /dev/null +++ b/core/src/main/scala/spark/api/java/function/FlatMapFunction2.scala @@ -0,0 +1,11 @@ +package spark.api.java.function + +/** + * A function that takes two inputs and returns zero or more output records. + */ +abstract class FlatMapFunction2[A, B, C] extends Function2[A, B, java.lang.Iterable[C]] { + @throws(classOf[Exception]) + def call(a: A, b:B) : java.lang.Iterable[C] + + def elementType() : ClassManifest[C] = ClassManifest.Any.asInstanceOf[ClassManifest[C]] +} diff --git a/core/src/main/scala/spark/api/python/PythonRDD.scala b/core/src/main/scala/spark/api/python/PythonRDD.scala index 220047c360..3d1e45cb2c 100644 --- a/core/src/main/scala/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/spark/api/python/PythonRDD.scala @@ -2,9 +2,10 @@ package spark.api.python import java.io._ import java.net._ -import java.util.{List => JList, ArrayList => JArrayList, Collections} +import java.util.{List => JList, ArrayList => JArrayList, Map => JMap, Collections} import scala.collection.JavaConversions._ + import scala.io.Source import scala.reflect.ClassTag @@ -17,16 +18,18 @@ import spark.rdd.PipedRDD private[spark] class PythonRDD[T: ClassTag]( parent: RDD[T], command: Seq[String], - envVars: java.util.Map[String, String], + envVars: JMap[String, String], preservePartitoning: Boolean, pythonExec: String, broadcastVars: JList[Broadcast[Array[Byte]]], accumulator: Accumulator[JList[Array[Byte]]]) extends RDD[Array[Byte]](parent) { + val bufferSize = System.getProperty("spark.buffer.size", "65536").toInt + // Similar 
to Runtime.exec(), if we are given a single string, split it into words // using a standard StringTokenizer (i.e. by spaces) - def this(parent: RDD[T], command: String, envVars: java.util.Map[String, String], + def this(parent: RDD[T], command: String, envVars: JMap[String, String], preservePartitoning: Boolean, pythonExec: String, broadcastVars: JList[Broadcast[Array[Byte]]], accumulator: Accumulator[JList[Array[Byte]]]) = @@ -37,68 +40,57 @@ private[spark] class PythonRDD[T: ClassTag]( override val partitioner = if (preservePartitoning) parent.partitioner else None - override def compute(split: Partition, context: TaskContext): Iterator[Array[Byte]] = { - val SPARK_HOME = new ProcessBuilder().environment().get("SPARK_HOME") - - val pb = new ProcessBuilder(Seq(pythonExec, SPARK_HOME + "/python/pyspark/worker.py")) - // Add the environmental variables to the process. - val currentEnvVars = pb.environment() - for ((variable, value) <- envVars) { - currentEnvVars.put(variable, value) - } - - val proc = pb.start() + override def compute(split: Partition, context: TaskContext): Iterator[Array[Byte]] = { + val startTime = System.currentTimeMillis val env = SparkEnv.get - - // Start a thread to print the process's stderr to ours - new Thread("stderr reader for " + pythonExec) { - override def run() { - for (line <- Source.fromInputStream(proc.getErrorStream).getLines) { - System.err.println(line) - } - } - }.start() + val worker = env.createPythonWorker(pythonExec, envVars.toMap) // Start a thread to feed the process input from our parent's iterator new Thread("stdin writer for " + pythonExec) { override def run() { SparkEnv.set(env) - val out = new PrintWriter(proc.getOutputStream) - val dOut = new DataOutputStream(proc.getOutputStream) + val stream = new BufferedOutputStream(worker.getOutputStream, bufferSize) + val dataOut = new DataOutputStream(stream) + val printOut = new PrintWriter(stream) // Partition index - dOut.writeInt(split.index) + 
dataOut.writeInt(split.index) // sparkFilesDir - PythonRDD.writeAsPickle(SparkFiles.getRootDirectory, dOut) + PythonRDD.writeAsPickle(SparkFiles.getRootDirectory, dataOut) // Broadcast variables - dOut.writeInt(broadcastVars.length) + dataOut.writeInt(broadcastVars.length) for (broadcast <- broadcastVars) { - dOut.writeLong(broadcast.id) - dOut.writeInt(broadcast.value.length) - dOut.write(broadcast.value) - dOut.flush() + dataOut.writeLong(broadcast.id) + dataOut.writeInt(broadcast.value.length) + dataOut.write(broadcast.value) } + dataOut.flush() // Serialized user code for (elem <- command) { - out.println(elem) + printOut.println(elem) } - out.flush() + printOut.flush() // Data values for (elem <- parent.iterator(split, context)) { - PythonRDD.writeAsPickle(elem, dOut) + PythonRDD.writeAsPickle(elem, dataOut) } - dOut.flush() - out.flush() - proc.getOutputStream.close() + dataOut.flush() + printOut.flush() + worker.shutdownOutput() } }.start() // Return an iterator that read lines from the process's stdout - val stream = new DataInputStream(proc.getInputStream) + val stream = new DataInputStream(new BufferedInputStream(worker.getInputStream, bufferSize)) return new Iterator[Array[Byte]] { def next(): Array[Byte] = { val obj = _nextObj - _nextObj = read() + if (hasNext) { + // FIXME: can deadlock if worker is waiting for us to + // respond to current message (currently irrelevant because + // output is shutdown before we read any input) + _nextObj = read() + } obj } @@ -109,6 +101,17 @@ private[spark] class PythonRDD[T: ClassTag]( val obj = new Array[Byte](length) stream.readFully(obj) obj + case -3 => + // Timing data from worker + val bootTime = stream.readLong() + val initTime = stream.readLong() + val finishTime = stream.readLong() + val boot = bootTime - startTime + val init = initTime - bootTime + val finish = finishTime - initTime + val total = finishTime - startTime + logInfo("Times: total = %s, boot = %s, init = %s, finish = %s".format(total, boot, 
init, finish)) + read case -2 => // Signals that an exception has been thrown in python val exLength = stream.readInt() @@ -116,23 +119,21 @@ private[spark] class PythonRDD[T: ClassTag]( stream.readFully(obj) throw new PythonException(new String(obj)) case -1 => - // We've finished the data section of the output, but we can still read some - // accumulator updates; let's do that, breaking when we get EOFException - while (true) { - val len2 = stream.readInt() + // We've finished the data section of the output, but we can still + // read some accumulator updates; let's do that, breaking when we + // get a negative length record. + var len2 = stream.readInt() + while (len2 >= 0) { val update = new Array[Byte](len2) stream.readFully(update) accumulator += Collections.singletonList(update) + len2 = stream.readInt() } new Array[Byte](0) } } catch { case eof: EOFException => { - val exitStatus = proc.waitFor() - if (exitStatus != 0) { - throw new Exception("Subprocess exited with status " + exitStatus) - } - new Array[Byte](0) + throw new SparkException("Python worker exited unexpectedly (crashed)", eof) } case e : Throwable => throw e } @@ -160,7 +161,7 @@ private class PairwiseRDD(prev: RDD[Array[Byte]]) extends override def compute(split: Partition, context: TaskContext) = prev.iterator(split, context).grouped(2).map { case Seq(a, b) => (a, b) - case x => throw new Exception("PairwiseRDD: unexpected value: " + x) + case x => throw new SparkException("PairwiseRDD: unexpected value: " + x) } val asJavaPairRDD : JavaPairRDD[Array[Byte], Array[Byte]] = JavaPairRDD.fromRDD(this) } @@ -216,7 +217,7 @@ private[spark] object PythonRDD { dOut.write(s) dOut.writeByte(Pickle.STOP) } else { - throw new Exception("Unexpected RDD type") + throw new SparkException("Unexpected RDD type") } } @@ -279,6 +280,10 @@ private class BytesToString extends spark.api.java.function.Function[Array[Byte] class PythonAccumulatorParam(@transient serverHost: String, serverPort: Int) extends 
AccumulatorParam[JList[Array[Byte]]] { + Utils.checkHost(serverHost, "Expected hostname") + + val bufferSize = System.getProperty("spark.buffer.size", "65536").toInt + override def zero(value: JList[Array[Byte]]): JList[Array[Byte]] = new JArrayList override def addInPlace(val1: JList[Array[Byte]], val2: JList[Array[Byte]]) @@ -291,7 +296,7 @@ class PythonAccumulatorParam(@transient serverHost: String, serverPort: Int) // This happens on the master, where we pass the updates to Python through a socket val socket = new Socket(serverHost, serverPort) val in = socket.getInputStream - val out = new DataOutputStream(socket.getOutputStream) + val out = new DataOutputStream(new BufferedOutputStream(socket.getOutputStream, bufferSize)) out.writeInt(val2.size) for (array <- val2) { out.writeInt(array.length) diff --git a/core/src/main/scala/spark/api/python/PythonWorkerFactory.scala b/core/src/main/scala/spark/api/python/PythonWorkerFactory.scala new file mode 100644 index 0000000000..85d1dfeac8 --- /dev/null +++ b/core/src/main/scala/spark/api/python/PythonWorkerFactory.scala @@ -0,0 +1,113 @@ +package spark.api.python + +import java.io.{DataInputStream, IOException} +import java.net.{Socket, SocketException, InetAddress} + +import scala.collection.JavaConversions._ + +import spark._ + +private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String, String]) + extends Logging { + var daemon: Process = null + val daemonHost = InetAddress.getByAddress(Array(127, 0, 0, 1)) + var daemonPort: Int = 0 + + def create(): Socket = { + synchronized { + // Start the daemon if it hasn't been started + startDaemon() + + // Attempt to connect, restart and retry once if it fails + try { + new Socket(daemonHost, daemonPort) + } catch { + case exc: SocketException => { + logWarning("Python daemon unexpectedly quit, attempting to restart") + stopDaemon() + startDaemon() + new Socket(daemonHost, daemonPort) + } + case e => throw e + } + } + } + + def stop() { + stopDaemon() 
+ } + + private def startDaemon() { + synchronized { + // Is it already running? + if (daemon != null) { + return + } + + try { + // Create and start the daemon + val sparkHome = new ProcessBuilder().environment().get("SPARK_HOME") + val pb = new ProcessBuilder(Seq(pythonExec, sparkHome + "/python/pyspark/daemon.py")) + val workerEnv = pb.environment() + workerEnv.putAll(envVars) + daemon = pb.start() + + // Redirect the stderr to ours + new Thread("stderr reader for " + pythonExec) { + override def run() { + scala.util.control.Exception.ignoring(classOf[IOException]) { + // FIXME HACK: We copy the stream on the level of bytes to + // attempt to dodge encoding problems. + val in = daemon.getErrorStream + var buf = new Array[Byte](1024) + var len = in.read(buf) + while (len != -1) { + System.err.write(buf, 0, len) + len = in.read(buf) + } + } + } + }.start() + + val in = new DataInputStream(daemon.getInputStream) + daemonPort = in.readInt() + + // Redirect further stdout output to our stderr + new Thread("stdout reader for " + pythonExec) { + override def run() { + scala.util.control.Exception.ignoring(classOf[IOException]) { + // FIXME HACK: We copy the stream on the level of bytes to + // attempt to dodge encoding problems. + var buf = new Array[Byte](1024) + var len = in.read(buf) + while (len != -1) { + System.err.write(buf, 0, len) + len = in.read(buf) + } + } + } + }.start() + } catch { + case e => { + stopDaemon() + throw e + } + } + + // Important: don't close daemon's stdin (daemon.getOutputStream) so it can correctly + // detect our disappearance. 
+ } + } + + private def stopDaemon() { + synchronized { + // Request shutdown of existing daemon by sending SIGTERM + if (daemon != null) { + daemon.destroy() + } + + daemon = null + daemonPort = 0 + } + } +} diff --git a/core/src/main/scala/spark/deploy/ApplicationDescription.scala b/core/src/main/scala/spark/deploy/ApplicationDescription.scala index 6659e53b25..02193c7008 100644 --- a/core/src/main/scala/spark/deploy/ApplicationDescription.scala +++ b/core/src/main/scala/spark/deploy/ApplicationDescription.scala @@ -2,10 +2,11 @@ package spark.deploy private[spark] class ApplicationDescription( val name: String, - val cores: Int, + val maxCores: Int, /* Integer.MAX_VALUE denotes an unlimited number of cores */ val memoryPerSlave: Int, val command: Command, - val sparkHome: String) + val sparkHome: String, + val appUiUrl: String) extends Serializable { val user = System.getProperty("user.name", "<unknown>") diff --git a/core/src/main/scala/spark/deploy/DeployMessage.scala b/core/src/main/scala/spark/deploy/DeployMessage.scala index 8a3e64e4c2..51274acb1e 100644 --- a/core/src/main/scala/spark/deploy/DeployMessage.scala +++ b/core/src/main/scala/spark/deploy/DeployMessage.scala @@ -4,6 +4,7 @@ import spark.deploy.ExecutorState.ExecutorState import spark.deploy.master.{WorkerInfo, ApplicationInfo} import spark.deploy.worker.ExecutorRunner import scala.collection.immutable.List +import spark.Utils private[spark] sealed trait DeployMessage extends Serializable @@ -19,7 +20,10 @@ case class RegisterWorker( memory: Int, webUiPort: Int, publicAddress: String) - extends DeployMessage + extends DeployMessage { + Utils.checkHost(host, "Required hostname") + assert (port > 0) +} private[spark] case class ExecutorStateChanged( @@ -58,7 +62,9 @@ private[spark] case class RegisteredApplication(appId: String) extends DeployMessage private[spark] -case class ExecutorAdded(id: Int, workerId: String, host: String, cores: Int, memory: Int) +case class ExecutorAdded(id: Int, 
workerId: String, hostPort: String, cores: Int, memory: Int) { + Utils.checkHostPort(hostPort, "Required hostport") +} private[spark] case class ExecutorUpdated(id: Int, state: ExecutorState, message: Option[String], @@ -81,6 +87,9 @@ private[spark] case class MasterState(host: String, port: Int, workers: Array[WorkerInfo], activeApps: Array[ApplicationInfo], completedApps: Array[ApplicationInfo]) { + Utils.checkHost(host, "Required hostname") + assert (port > 0) + def uri = "spark://" + host + ":" + port } @@ -92,4 +101,8 @@ private[spark] case object RequestWorkerState private[spark] case class WorkerState(host: String, port: Int, workerId: String, executors: List[ExecutorRunner], finishedExecutors: List[ExecutorRunner], masterUrl: String, cores: Int, memory: Int, - coresUsed: Int, memoryUsed: Int, masterWebUiUrl: String) + coresUsed: Int, memoryUsed: Int, masterWebUiUrl: String) { + + Utils.checkHost(host, "Required hostname") + assert (port > 0) +} diff --git a/core/src/main/scala/spark/deploy/JsonProtocol.scala b/core/src/main/scala/spark/deploy/JsonProtocol.scala index 702defb628..88b03a007c 100644 --- a/core/src/main/scala/spark/deploy/JsonProtocol.scala +++ b/core/src/main/scala/spark/deploy/JsonProtocol.scala @@ -12,6 +12,7 @@ private[spark] object JsonProtocol extends DefaultJsonProtocol { def write(obj: WorkerInfo) = JsObject( "id" -> JsString(obj.id), "host" -> JsString(obj.host), + "port" -> JsNumber(obj.port), "webuiaddress" -> JsString(obj.webUiAddress), "cores" -> JsNumber(obj.cores), "coresused" -> JsNumber(obj.coresUsed), @@ -25,7 +26,7 @@ private[spark] object JsonProtocol extends DefaultJsonProtocol { "starttime" -> JsNumber(obj.startTime), "id" -> JsString(obj.id), "name" -> JsString(obj.desc.name), - "cores" -> JsNumber(obj.desc.cores), + "cores" -> JsNumber(obj.desc.maxCores), "user" -> JsString(obj.desc.user), "memoryperslave" -> JsNumber(obj.desc.memoryPerSlave), "submitdate" -> JsString(obj.submitDate.toString)) @@ -34,7 +35,7 @@ 
private[spark] object JsonProtocol extends DefaultJsonProtocol { implicit object AppDescriptionJsonFormat extends RootJsonWriter[ApplicationDescription] { def write(obj: ApplicationDescription) = JsObject( "name" -> JsString(obj.name), - "cores" -> JsNumber(obj.cores), + "cores" -> JsNumber(obj.maxCores), "memoryperslave" -> JsNumber(obj.memoryPerSlave), "user" -> JsString(obj.user) ) diff --git a/core/src/main/scala/spark/deploy/LocalSparkCluster.scala b/core/src/main/scala/spark/deploy/LocalSparkCluster.scala index 6abaaeaa3f..2b0b3b10e7 100644 --- a/core/src/main/scala/spark/deploy/LocalSparkCluster.scala +++ b/core/src/main/scala/spark/deploy/LocalSparkCluster.scala @@ -18,7 +18,7 @@ import scala.collection.mutable.ArrayBuffer private[spark] class LocalSparkCluster(numWorkers: Int, coresPerWorker: Int, memoryPerWorker: Int) extends Logging { - private val localIpAddress = Utils.localIpAddress + private val localHostname = Utils.localHostName() private val masterActorSystems = ArrayBuffer[ActorSystem]() private val workerActorSystems = ArrayBuffer[ActorSystem]() @@ -26,13 +26,13 @@ class LocalSparkCluster(numWorkers: Int, coresPerWorker: Int, memoryPerWorker: I logInfo("Starting a local Spark cluster with " + numWorkers + " workers.") /* Start the Master */ - val (masterSystem, masterPort) = Master.startSystemAndActor(localIpAddress, 0, 0) + val (masterSystem, masterPort) = Master.startSystemAndActor(localHostname, 0, 0) masterActorSystems += masterSystem - val masterUrl = "spark://" + localIpAddress + ":" + masterPort + val masterUrl = "spark://" + localHostname + ":" + masterPort /* Start the Workers */ for (workerNum <- 1 to numWorkers) { - val (workerSystem, _) = Worker.startSystemAndActor(localIpAddress, 0, 0, coresPerWorker, + val (workerSystem, _) = Worker.startSystemAndActor(localHostname, 0, 0, coresPerWorker, memoryPerWorker, masterUrl, null, Some(workerNum)) workerActorSystems += workerSystem } diff --git 
a/core/src/main/scala/spark/deploy/client/Client.scala b/core/src/main/scala/spark/deploy/client/Client.scala index a38218a391..690bb20e50 100644 --- a/core/src/main/scala/spark/deploy/client/Client.scala +++ b/core/src/main/scala/spark/deploy/client/Client.scala @@ -4,6 +4,7 @@ import spark.deploy._ import akka.actor._ import akka.pattern.ask import scala.concurrent.duration._ + import akka.pattern.AskTimeoutException import spark.{SparkException, Logging} import akka.remote.RemoteClientLifeCycleEvent @@ -59,10 +60,10 @@ private[spark] class Client( markDisconnected() context.stop(self) - case ExecutorAdded(id: Int, workerId: String, host: String, cores: Int, memory: Int) => + case ExecutorAdded(id: Int, workerId: String, hostPort: String, cores: Int, memory: Int) => val fullId = appId + "/" + id - logInfo("Executor added: %s on %s (%s) with %d cores".format(fullId, workerId, host, cores)) - listener.executorAdded(fullId, workerId, host, cores, memory) + logInfo("Executor added: %s on %s (%s) with %d cores".format(fullId, workerId, hostPort, cores)) + listener.executorAdded(fullId, workerId, hostPort, cores, memory) case ExecutorUpdated(id, state, message, exitStatus) => val fullId = appId + "/" + id @@ -112,7 +113,7 @@ private[spark] class Client( def stop() { if (actor != null) { try { - val timeout = 5.seconds + val timeout = Duration.create(System.getProperty("spark.akka.askTimeout", "10").toLong, "seconds") val future = actor.ask(StopClient)(timeout) Await.result(future, timeout) } catch { diff --git a/core/src/main/scala/spark/deploy/client/ClientListener.scala b/core/src/main/scala/spark/deploy/client/ClientListener.scala index b7008321df..e8c4083f9d 100644 --- a/core/src/main/scala/spark/deploy/client/ClientListener.scala +++ b/core/src/main/scala/spark/deploy/client/ClientListener.scala @@ -12,7 +12,7 @@ private[spark] trait ClientListener { def disconnected(): Unit - def executorAdded(fullId: String, workerId: String, host: String, cores: Int, memory: 
Int): Unit + def executorAdded(fullId: String, workerId: String, hostPort: String, cores: Int, memory: Int): Unit def executorRemoved(fullId: String, message: String, exitStatus: Option[Int]): Unit } diff --git a/core/src/main/scala/spark/deploy/client/TestClient.scala b/core/src/main/scala/spark/deploy/client/TestClient.scala index dc004b59ca..f195082808 100644 --- a/core/src/main/scala/spark/deploy/client/TestClient.scala +++ b/core/src/main/scala/spark/deploy/client/TestClient.scala @@ -16,7 +16,7 @@ private[spark] object TestClient { System.exit(0) } - def executorAdded(id: String, workerId: String, host: String, cores: Int, memory: Int) {} + def executorAdded(id: String, workerId: String, hostPort: String, cores: Int, memory: Int) {} def executorRemoved(id: String, message: String, exitStatus: Option[Int]) {} } @@ -25,7 +25,7 @@ private[spark] object TestClient { val url = args(0) val (actorSystem, port) = AkkaUtils.createActorSystem("spark", Utils.localIpAddress, 0) val desc = new ApplicationDescription( - "TestClient", 1, 512, Command("spark.deploy.client.TestExecutor", Seq(), Map()), "dummy-spark-home") + "TestClient", 1, 512, Command("spark.deploy.client.TestExecutor", Seq(), Map()), "dummy-spark-home", "ignored") val listener = new TestListener val client = new Client(actorSystem, url, desc, listener) client.start() diff --git a/core/src/main/scala/spark/deploy/master/ApplicationInfo.scala b/core/src/main/scala/spark/deploy/master/ApplicationInfo.scala index 3591a94072..785c16e2be 100644 --- a/core/src/main/scala/spark/deploy/master/ApplicationInfo.scala +++ b/core/src/main/scala/spark/deploy/master/ApplicationInfo.scala @@ -10,7 +10,8 @@ private[spark] class ApplicationInfo( val id: String, val desc: ApplicationDescription, val submitDate: Date, - val driver: ActorRef) + val driver: ActorRef, + val appUiUrl: String) { var state = ApplicationState.WAITING var executors = new mutable.HashMap[Int, ExecutorInfo] @@ -37,7 +38,7 @@ private[spark] class 
ApplicationInfo( coresGranted -= exec.cores } - def coresLeft: Int = desc.cores - coresGranted + def coresLeft: Int = desc.maxCores - coresGranted private var _retryCount = 0 @@ -60,4 +61,5 @@ private[spark] class ApplicationInfo( System.currentTimeMillis() - startTime } } + } diff --git a/core/src/main/scala/spark/deploy/master/Master.scala b/core/src/main/scala/spark/deploy/master/Master.scala index d1428bcfc6..770cfe9d05 100644 --- a/core/src/main/scala/spark/deploy/master/Master.scala +++ b/core/src/main/scala/spark/deploy/master/Master.scala @@ -15,7 +15,7 @@ import spark.{Logging, SparkException, Utils} import spark.util.AkkaUtils -private[spark] class Master(ip: String, port: Int, webUiPort: Int) extends Actor with Logging { +private[spark] class Master(host: String, port: Int, webUiPort: Int) extends Actor with Logging { val DATE_FORMAT = new SimpleDateFormat("yyyyMMddHHmmss") // For application IDs val WORKER_TIMEOUT = System.getProperty("spark.worker.timeout", "60").toLong * 1000 @@ -35,9 +35,11 @@ private[spark] class Master(ip: String, port: Int, webUiPort: Int) extends Actor var firstApp: Option[ApplicationInfo] = None + Utils.checkHost(host, "Expected hostname") + val masterPublicAddress = { val envVar = System.getenv("SPARK_PUBLIC_DNS") - if (envVar != null) envVar else ip + if (envVar != null) envVar else host } // As a temporary workaround before better ways of configuring memory, we allow users to set @@ -46,7 +48,7 @@ private[spark] class Master(ip: String, port: Int, webUiPort: Int) extends Actor val spreadOutApps = System.getProperty("spark.deploy.spreadOut", "true").toBoolean override def preStart() { - logInfo("Starting Spark master at spark://" + ip + ":" + port) + logInfo("Starting Spark master at spark://" + host + ":" + port) // Listen for remote client disconnection events, since they don't go through Akka's watch() context.system.eventStream.subscribe(self, classOf[RemoteClientLifeCycleEvent]) startWebUi() @@ -146,7 +148,7 @@ 
private[spark] class Master(ip: String, port: Int, webUiPort: Int) extends Actor } case RequestMasterState => { - sender ! MasterState(ip, port, workers.toArray, apps.toArray, completedApps.toArray) + sender ! MasterState(host, port, workers.toArray, apps.toArray, completedApps.toArray) } } @@ -212,13 +214,13 @@ private[spark] class Master(ip: String, port: Int, webUiPort: Int) extends Actor logInfo("Launching executor " + exec.fullId + " on worker " + worker.id) worker.addExecutor(exec) worker.actor ! LaunchExecutor(exec.application.id, exec.id, exec.application.desc, exec.cores, exec.memory, sparkHome) - exec.application.driver ! ExecutorAdded(exec.id, worker.id, worker.host, exec.cores, exec.memory) + exec.application.driver ! ExecutorAdded(exec.id, worker.id, worker.hostPort, exec.cores, exec.memory) } def addWorker(id: String, host: String, port: Int, cores: Int, memory: Int, webUiPort: Int, publicAddress: String): WorkerInfo = { // There may be one or more refs to dead workers on this same node (w/ different ID's), remove them. 
- workers.filter(w => (w.host == host) && (w.state == WorkerState.DEAD)).foreach(workers -= _) + workers.filter(w => (w.host == host && w.port == port) && (w.state == WorkerState.DEAD)).foreach(workers -= _) val worker = new WorkerInfo(id, host, port, cores, memory, sender, webUiPort, publicAddress) workers += worker idToWorker(worker.id) = worker @@ -243,7 +245,7 @@ private[spark] class Master(ip: String, port: Int, webUiPort: Int) extends Actor def addApplication(desc: ApplicationDescription, driver: ActorRef): ApplicationInfo = { val now = System.currentTimeMillis() val date = new Date(now) - val app = new ApplicationInfo(now, newApplicationId(date), desc, date, driver) + val app = new ApplicationInfo(now, newApplicationId(date), desc, date, driver, desc.appUiUrl) apps += app idToApp(app.id) = app actorToApp(driver) = app @@ -274,6 +276,7 @@ private[spark] class Master(ip: String, port: Int, webUiPort: Int) extends Actor for (exec <- app.executors.values) { exec.worker.removeExecutor(exec) exec.worker.actor ! KillExecutor(exec.application.id, exec.id) + exec.state = ExecutorState.KILLED } app.markFinished(state) app.driver ! ApplicationRemoved(state.toString) @@ -308,7 +311,7 @@ private[spark] object Master { def main(argStrings: Array[String]) { val args = new MasterArguments(argStrings) - val (actorSystem, _) = startSystemAndActor(args.ip, args.port, args.webUiPort) + val (actorSystem, _) = startSystemAndActor(args.host, args.port, args.webUiPort) actorSystem.awaitTermination() } diff --git a/core/src/main/scala/spark/deploy/master/MasterArguments.scala b/core/src/main/scala/spark/deploy/master/MasterArguments.scala index 4ceab3fc03..3d28ecabb4 100644 --- a/core/src/main/scala/spark/deploy/master/MasterArguments.scala +++ b/core/src/main/scala/spark/deploy/master/MasterArguments.scala @@ -7,13 +7,13 @@ import spark.Utils * Command-line parser for the master. 
*/ private[spark] class MasterArguments(args: Array[String]) { - var ip = Utils.localHostName() + var host = Utils.localHostName() var port = 7077 var webUiPort = 8080 // Check for settings in environment variables - if (System.getenv("SPARK_MASTER_IP") != null) { - ip = System.getenv("SPARK_MASTER_IP") + if (System.getenv("SPARK_MASTER_HOST") != null) { + host = System.getenv("SPARK_MASTER_HOST") } if (System.getenv("SPARK_MASTER_PORT") != null) { port = System.getenv("SPARK_MASTER_PORT").toInt @@ -26,7 +26,13 @@ private[spark] class MasterArguments(args: Array[String]) { def parse(args: List[String]): Unit = args match { case ("--ip" | "-i") :: value :: tail => - ip = value + Utils.checkHost(value, "ip no longer supported, please use hostname " + value) + host = value + parse(tail) + + case ("--host" | "-h") :: value :: tail => + Utils.checkHost(value, "Please use hostname " + value) + host = value parse(tail) case ("--port" | "-p") :: IntParam(value) :: tail => @@ -54,7 +60,8 @@ private[spark] class MasterArguments(args: Array[String]) { "Usage: Master [options]\n" + "\n" + "Options:\n" + - " -i IP, --ip IP IP address or DNS name to listen on\n" + + " -i HOST, --ip HOST Hostname to listen on (deprecated, please use --host or -h) \n" + + " -h HOST, --host HOST Hostname to listen on\n" + " -p PORT, --port PORT Port to listen on (default: 7077)\n" + " --webui-port PORT Port for web UI (default: 8080)") System.exit(exitCode) diff --git a/core/src/main/scala/spark/deploy/master/MasterWebUI.scala b/core/src/main/scala/spark/deploy/master/MasterWebUI.scala index fe859d48c3..34cee87853 100644 --- a/core/src/main/scala/spark/deploy/master/MasterWebUI.scala +++ b/core/src/main/scala/spark/deploy/master/MasterWebUI.scala @@ -3,6 +3,7 @@ package spark.deploy.master import akka.actor.{ActorRef, ActorContext, ActorRefFactory} import scala.concurrent.Await import akka.pattern.ask + import akka.util.Timeout import scala.concurrent.duration._ import spray.routing.Directives @@ 
-25,8 +26,7 @@ class MasterWebUI(master: ActorRef)(implicit val context: ActorContext) extends val RESOURCE_DIR = "spark/deploy/master/webui" val STATIC_RESOURCE_DIR = "spark/deploy/static" - implicit val timeout = Timeout(10 seconds) - + implicit val timeout = Timeout(Duration.create(System.getProperty("spark.akka.askTimeout", "10").toLong, "seconds")) val handler = { get { diff --git a/core/src/main/scala/spark/deploy/master/WorkerInfo.scala b/core/src/main/scala/spark/deploy/master/WorkerInfo.scala index 23df1bb463..0c08c5f417 100644 --- a/core/src/main/scala/spark/deploy/master/WorkerInfo.scala +++ b/core/src/main/scala/spark/deploy/master/WorkerInfo.scala @@ -2,6 +2,7 @@ package spark.deploy.master import akka.actor.ActorRef import scala.collection.mutable +import spark.Utils private[spark] class WorkerInfo( val id: String, @@ -13,6 +14,9 @@ private[spark] class WorkerInfo( val webUiPort: Int, val publicAddress: String) { + Utils.checkHost(host, "Expected hostname") + assert (port > 0) + var executors = new mutable.HashMap[String, ExecutorInfo] // fullId => info var state: WorkerState.Value = WorkerState.ALIVE var coresUsed = 0 @@ -23,6 +27,11 @@ private[spark] class WorkerInfo( def coresFree: Int = cores - coresUsed def memoryFree: Int = memory - memoryUsed + def hostPort: String = { + assert (port > 0) + host + ":" + port + } + def addExecutor(exec: ExecutorInfo) { executors(exec.fullId) = exec coresUsed += exec.cores diff --git a/core/src/main/scala/spark/deploy/worker/ExecutorRunner.scala b/core/src/main/scala/spark/deploy/worker/ExecutorRunner.scala index de11771c8e..d7f58b2cb1 100644 --- a/core/src/main/scala/spark/deploy/worker/ExecutorRunner.scala +++ b/core/src/main/scala/spark/deploy/worker/ExecutorRunner.scala @@ -1,6 +1,7 @@ package spark.deploy.worker import java.io._ +import java.lang.System.getenv import spark.deploy.{ExecutorState, ExecutorStateChanged, ApplicationDescription} import akka.actor.ActorRef import spark.{Utils, Logging} @@ -21,11 
+22,13 @@ private[spark] class ExecutorRunner( val memory: Int, val worker: ActorRef, val workerId: String, - val hostname: String, + val hostPort: String, val sparkHome: File, val workDir: File) extends Logging { + Utils.checkHostPort(hostPort, "Expected hostport") + val fullId = appId + "/" + execId var workerThread: Thread = null var process: Process = null @@ -38,7 +41,7 @@ private[spark] class ExecutorRunner( workerThread.start() // Shutdown hook that kills actors on shutdown. - shutdownHook = new Thread() { + shutdownHook = new Thread() { override def run() { if (process != null) { logInfo("Shutdown hook killing child process.") @@ -68,16 +71,36 @@ private[spark] class ExecutorRunner( /** Replace variables such as {{EXECUTOR_ID}} and {{CORES}} in a command argument passed to us */ def substituteVariables(argument: String): String = argument match { case "{{EXECUTOR_ID}}" => execId.toString - case "{{HOSTNAME}}" => hostname + case "{{HOSTNAME}}" => Utils.parseHostPort(hostPort)._1 case "{{CORES}}" => cores.toString case other => other } def buildCommandSeq(): Seq[String] = { val command = appDesc.command - val script = if (System.getProperty("os.name").startsWith("Windows")) "run.cmd" else "run" - val runScript = new File(sparkHome, script).getCanonicalPath - Seq(runScript, command.mainClass) ++ (command.arguments ++ Seq(appId)).map(substituteVariables) + val runner = Option(getenv("JAVA_HOME")).map(_ + "/bin/java").getOrElse("java") + // SPARK-698: do not call the run.cmd script, as process.destroy() + // fails to kill a process tree on Windows + Seq(runner) ++ buildJavaOpts() ++ Seq(command.mainClass) ++ + command.arguments.map(substituteVariables) + } + + /** + * Attention: this must always be aligned with the environment variables in the run scripts and + * the way the JAVA_OPTS are assembled there. 
+ */ + def buildJavaOpts(): Seq[String] = { + val libraryOpts = Option(getenv("SPARK_LIBRARY_PATH")) + .map(p => List("-Djava.library.path=" + p)) + .getOrElse(Nil) + val userOpts = Option(getenv("SPARK_JAVA_OPTS")).map(Utils.splitCommandString).getOrElse(Nil) + val memoryOpts = Seq("-Xms" + memory + "M", "-Xmx" + memory + "M") + + // Figure out our classpath with the external compute-classpath script + val ext = if (System.getProperty("os.name").startsWith("Windows")) ".cmd" else ".sh" + val classPath = Utils.executeAndGetOutput(Seq(sparkHome + "/bin/compute-classpath" + ext)) + + Seq("-cp", classPath) ++ libraryOpts ++ userOpts ++ memoryOpts } /** Spawn a thread that will redirect a given stream to a file */ @@ -113,7 +136,6 @@ private[spark] class ExecutorRunner( for ((key, value) <- appDesc.command.environment) { env.put(key, value) } - env.put("SPARK_MEM", memory.toString + "m") // In case we are running this from within the Spark Shell, avoid creating a "scala" // parent process for the executor command env.put("SPARK_LAUNCH_WITH_SCALA", "0") diff --git a/core/src/main/scala/spark/deploy/worker/Worker.scala b/core/src/main/scala/spark/deploy/worker/Worker.scala index 5bcf00443c..b5dfd16e67 100644 --- a/core/src/main/scala/spark/deploy/worker/Worker.scala +++ b/core/src/main/scala/spark/deploy/worker/Worker.scala @@ -17,7 +17,7 @@ import java.io.File private[spark] class Worker( - ip: String, + host: String, port: Int, webUiPort: Int, cores: Int, @@ -26,6 +26,9 @@ private[spark] class Worker( workDirPath: String = null) extends Actor with Logging { + Utils.checkHost(host, "Expected hostname") + assert (port > 0) + val DATE_FORMAT = new SimpleDateFormat("yyyyMMddHHmmss") // For worker and executor IDs // Send a heartbeat every (heartbeat timeout) / 4 milliseconds @@ -40,7 +43,7 @@ private[spark] class Worker( val finishedExecutors = new HashMap[String, ExecutorRunner] val publicAddress = { val envVar = System.getenv("SPARK_PUBLIC_DNS") - if (envVar != null) 
envVar else ip + if (envVar != null) envVar else host } var coresUsed = 0 @@ -52,10 +55,14 @@ private[spark] class Worker( def createWorkDir() { workDir = Option(workDirPath).map(new File(_)).getOrElse(new File(sparkHome, "work")) try { - if (!workDir.exists() && !workDir.mkdirs()) { + // This sporadically fails - not sure why ... !workDir.exists() && !workDir.mkdirs() + // So attempting to create and then check if directory was created or not. + workDir.mkdirs() + if ( !workDir.exists() || !workDir.isDirectory) { logError("Failed to create work directory " + workDir) System.exit(1) } + assert (workDir.isDirectory) } catch { case e: Exception => logError("Failed to create work directory " + workDir, e) @@ -65,7 +72,7 @@ private[spark] class Worker( override def preStart() { logInfo("Starting Spark worker %s:%d with %d cores, %s RAM".format( - ip, port, cores, Utils.memoryMegabytesToString(memory))) + host, port, cores, Utils.memoryMegabytesToString(memory))) sparkHome = new File(Option(System.getenv("SPARK_HOME")).getOrElse(".")) logInfo("Spark home: " + sparkHome) createWorkDir() @@ -76,7 +83,7 @@ private[spark] class Worker( def connectToMaster() { logInfo("Connecting to master " + masterUrl) master = context.actorFor(Master.toAkkaUrl(masterUrl)) - master ! RegisterWorker(workerId, ip, port, cores, memory, webUiPort, publicAddress) + master ! 
RegisterWorker(workerId, host, port, cores, memory, webUiPort, publicAddress) context.system.eventStream.subscribe(self, classOf[RemoteClientLifeCycleEvent]) context.watch(master) // Doesn't work with remote actors, but useful for testing } @@ -108,7 +115,7 @@ private[spark] class Worker( case LaunchExecutor(appId, execId, appDesc, cores_, memory_, execSparkHome_) => logInfo("Asked to launch executor %s/%d for %s".format(appId, execId, appDesc.name)) val manager = new ExecutorRunner( - appId, execId, appDesc, cores_, memory_, self, workerId, ip, new File(execSparkHome_), workDir) + appId, execId, appDesc, cores_, memory_, self, workerId, host + ":" + port, new File(execSparkHome_), workDir) executors(appId + "/" + execId) = manager manager.start() coresUsed += cores_ @@ -143,7 +150,7 @@ private[spark] class Worker( masterDisconnected() case RequestWorkerState => { - sender ! WorkerState(ip, port, workerId, executors.values.toList, + sender ! WorkerState(host, port, workerId, executors.values.toList, finishedExecutors.values.toList, masterUrl, cores, memory, coresUsed, memoryUsed, masterWebUiUrl) } @@ -158,7 +165,7 @@ private[spark] class Worker( } def generateWorkerId(): String = { - "worker-%s-%s-%d".format(DATE_FORMAT.format(new Date), ip, port) + "worker-%s-%s-%d".format(DATE_FORMAT.format(new Date), host, port) } override def postStop() { @@ -169,7 +176,7 @@ private[spark] class Worker( private[spark] object Worker { def main(argStrings: Array[String]) { val args = new WorkerArguments(argStrings) - val (actorSystem, _) = startSystemAndActor(args.ip, args.port, args.webUiPort, args.cores, + val (actorSystem, _) = startSystemAndActor(args.host, args.port, args.webUiPort, args.cores, args.memory, args.master, args.workDir) actorSystem.awaitTermination() } diff --git a/core/src/main/scala/spark/deploy/worker/WorkerArguments.scala b/core/src/main/scala/spark/deploy/worker/WorkerArguments.scala index 08f02bad80..2b96611ee3 100644 --- 
a/core/src/main/scala/spark/deploy/worker/WorkerArguments.scala +++ b/core/src/main/scala/spark/deploy/worker/WorkerArguments.scala @@ -9,7 +9,7 @@ import java.lang.management.ManagementFactory * Command-line parser for the master. */ private[spark] class WorkerArguments(args: Array[String]) { - var ip = Utils.localHostName() + var host = Utils.localHostName() var port = 0 var webUiPort = 8081 var cores = inferDefaultCores() @@ -38,7 +38,13 @@ private[spark] class WorkerArguments(args: Array[String]) { def parse(args: List[String]): Unit = args match { case ("--ip" | "-i") :: value :: tail => - ip = value + Utils.checkHost(value, "ip no longer supported, please use hostname " + value) + host = value + parse(tail) + + case ("--host" | "-h") :: value :: tail => + Utils.checkHost(value, "Please use hostname " + value) + host = value parse(tail) case ("--port" | "-p") :: IntParam(value) :: tail => @@ -93,7 +99,8 @@ private[spark] class WorkerArguments(args: Array[String]) { " -c CORES, --cores CORES Number of cores to use\n" + " -m MEM, --memory MEM Amount of memory to use (e.g. 
1000M, 2G)\n" + " -d DIR, --work-dir DIR Directory to run apps in (default: SPARK_HOME/work)\n" + - " -i IP, --ip IP IP address or DNS name to listen on\n" + + " -i HOST, --ip IP Hostname to listen on (deprecated, please use --host or -h)\n" + + " -h HOST, --host HOST Hostname to listen on\n" + " -p PORT, --port PORT Port to listen on (default: random)\n" + " --webui-port PORT Port for web UI (default: 8081)") System.exit(exitCode) diff --git a/core/src/main/scala/spark/deploy/worker/WorkerWebUI.scala b/core/src/main/scala/spark/deploy/worker/WorkerWebUI.scala index 33a2a9516e..cc2ab6187a 100644 --- a/core/src/main/scala/spark/deploy/worker/WorkerWebUI.scala +++ b/core/src/main/scala/spark/deploy/worker/WorkerWebUI.scala @@ -3,6 +3,7 @@ package spark.deploy.worker import akka.actor.{ActorRef, ActorContext} import scala.concurrent.Await import akka.pattern.ask + import akka.util.Timeout import scala.concurrent.duration._ import spray.routing.Directives @@ -25,7 +26,7 @@ class WorkerWebUI(worker: ActorRef, workDir: File)(implicit val context: ActorCo val RESOURCE_DIR = "spark/deploy/worker/webui" val STATIC_RESOURCE_DIR = "spark/deploy/static" - implicit val timeout = Timeout(10 seconds) + implicit val timeout = Timeout(Duration.create(System.getProperty("spark.akka.askTimeout", "10").toLong, "seconds")) val handler = { get { diff --git a/core/src/main/scala/spark/executor/Executor.scala b/core/src/main/scala/spark/executor/Executor.scala index 3e7407b58d..2bf55ea9a9 100644 --- a/core/src/main/scala/spark/executor/Executor.scala +++ b/core/src/main/scala/spark/executor/Executor.scala @@ -17,7 +17,7 @@ import java.nio.ByteBuffer * The Mesos executor for Spark. */ private[spark] class Executor(executorId: String, slaveHostname: String, properties: Seq[(String, String)]) extends Logging { - + // Application dependencies (added through SparkContext) that we've fetched so far on this node. 
// Each map holds the master's timestamp for the version of that file or JAR we got. private val currentFiles: HashMap[String, Long] = new HashMap[String, Long]() @@ -27,6 +27,11 @@ private[spark] class Executor(executorId: String, slaveHostname: String, propert initLogging() + // No ip or host:port - just hostname + Utils.checkHost(slaveHostname, "Expected executed slave to be a hostname") + // must not have port specified. + assert (0 == Utils.parseHostPort(slaveHostname)._2) + // Make sure the local hostname we report matches the cluster scheduler's name for this host Utils.setCustomHostname(slaveHostname) @@ -37,7 +42,8 @@ private[spark] class Executor(executorId: String, slaveHostname: String, propert // Create our ClassLoader and set it on this thread private val urlClassLoader = createClassLoader() - Thread.currentThread.setContextClassLoader(urlClassLoader) + private val replClassLoader = addReplClassLoaderIfNeeded(urlClassLoader) + Thread.currentThread.setContextClassLoader(replClassLoader) // Make any thread terminations due to uncaught exceptions kill the entire // executor process to avoid surprising stalls. 
@@ -67,6 +73,7 @@ private[spark] class Executor(executorId: String, slaveHostname: String, propert // Initialize Spark environment (using system properties read above) val env = SparkEnv.createFromSystemProperties(executorId, slaveHostname, 0, false, false) SparkEnv.set(env) + private val akkaFrameSize = env.actorSystem.settings.config.getBytes("akka.remote.netty.message-frame-size") // Start worker thread pool val threadPool = new ThreadPoolExecutor( @@ -82,7 +89,7 @@ private[spark] class Executor(executorId: String, slaveHostname: String, propert override def run() { val startTime = System.currentTimeMillis() SparkEnv.set(env) - Thread.currentThread.setContextClassLoader(urlClassLoader) + Thread.currentThread.setContextClassLoader(replClassLoader) val ser = SparkEnv.get.closureSerializer.newInstance() logInfo("Running task ID " + taskId) context.statusUpdate(taskId, TaskState.RUNNING, EMPTY_BYTE_BUFFER) @@ -98,6 +105,7 @@ private[spark] class Executor(executorId: String, slaveHostname: String, propert val value = task.run(taskId.toInt) val taskFinish = System.currentTimeMillis() task.metrics.foreach{ m => + m.hostname = Utils.localHostName m.executorDeserializeTime = (taskStart - startTime).toInt m.executorRunTime = (taskFinish - taskStart).toInt } @@ -108,6 +116,10 @@ private[spark] class Executor(executorId: String, slaveHostname: String, propert val result = new TaskResult(value, accumUpdates, task.metrics.getOrElse(null)) val serializedResult = ser.serialize(result) logInfo("Serialized size of result for " + taskId + " is " + serializedResult.limit) + if (serializedResult.limit >= (akkaFrameSize - 1024)) { + context.statusUpdate(taskId, TaskState.FAILED, ser.serialize(TaskResultTooBigFailure())) + return + } context.statusUpdate(taskId, TaskState.FINISHED, serializedResult) logInfo("Finished task ID " + taskId) } catch { @@ -117,7 +129,7 @@ private[spark] class Executor(executorId: String, slaveHostname: String, propert } case t: Throwable => { - val reason = 
ExceptionFailure(t) + val reason = ExceptionFailure(t.getClass.getName, t.toString, t.getStackTrace) context.statusUpdate(taskId, TaskState.FAILED, ser.serialize(reason)) // TODO: Should we exit the whole executor here? On the one hand, the failed task may @@ -142,26 +154,31 @@ private[spark] class Executor(executorId: String, slaveHostname: String, propert val urls = currentJars.keySet.map { uri => new File(uri.split("/").last).toURI.toURL }.toArray - loader = new URLClassLoader(urls, loader) + new ExecutorURLClassLoader(urls, loader) + } - // If the REPL is in use, add another ClassLoader that will read - // new classes defined by the REPL as the user types code + /** + * If the REPL is in use, add another ClassLoader that will read + * new classes defined by the REPL as the user types code + */ + private def addReplClassLoaderIfNeeded(parent: ClassLoader): ClassLoader = { val classUri = System.getProperty("spark.repl.class.uri") if (classUri != null) { logInfo("Using REPL class URI: " + classUri) - loader = { - try { - val klass = Class.forName("spark.repl.ExecutorClassLoader") - .asInstanceOf[Class[_ <: ClassLoader]] - val constructor = klass.getConstructor(classOf[String], classOf[ClassLoader]) - constructor.newInstance(classUri, loader) - } catch { - case _: ClassNotFoundException => loader - } + try { + val klass = Class.forName("spark.repl.ExecutorClassLoader") + .asInstanceOf[Class[_ <: ClassLoader]] + val constructor = klass.getConstructor(classOf[String], classOf[ClassLoader]) + return constructor.newInstance(classUri, parent) + } catch { + case _: ClassNotFoundException => + logError("Could not find spark.repl.ExecutorClassLoader on classpath!") + System.exit(1) + null } + } else { + return parent } - - return new ExecutorURLClassLoader(Array(), loader) } /** diff --git a/core/src/main/scala/spark/executor/StandaloneExecutorBackend.scala b/core/src/main/scala/spark/executor/StandaloneExecutorBackend.scala index 1047f71c6a..ebe2ac68d8 100644 --- 
a/core/src/main/scala/spark/executor/StandaloneExecutorBackend.scala +++ b/core/src/main/scala/spark/executor/StandaloneExecutorBackend.scala @@ -12,23 +12,27 @@ import spark.scheduler.cluster.RegisteredExecutor import spark.scheduler.cluster.LaunchTask import spark.scheduler.cluster.RegisterExecutorFailed import spark.scheduler.cluster.RegisterExecutor +import spark.Utils +import spark.deploy.SparkHadoopUtil private[spark] class StandaloneExecutorBackend( driverUrl: String, executorId: String, - hostname: String, + hostPort: String, cores: Int) extends Actor with ExecutorBackend with Logging { + Utils.checkHostPort(hostPort, "Expected hostport") + var executor: Executor = null var driver: ActorRef = null override def preStart() { logInfo("Connecting to driver: " + driverUrl) driver = context.actorFor(driverUrl) - driver ! RegisterExecutor(executorId, hostname, cores) + driver ! RegisterExecutor(executorId, hostPort, cores) context.system.eventStream.subscribe(self, classOf[RemoteClientLifeCycleEvent]) context.watch(driver) // Doesn't work with remote actors, but useful for testing } @@ -36,7 +40,8 @@ private[spark] class StandaloneExecutorBackend( override def receive = { case RegisteredExecutor(sparkProperties) => logInfo("Successfully registered with driver") - executor = new Executor(executorId, hostname, sparkProperties) + // Make this host instead of hostPort ? 
+ executor = new Executor(executorId, Utils.parseHostPort(hostPort)._1, sparkProperties) case RegisterExecutorFailed(message) => logError("Slave registration failed: " + message) @@ -63,11 +68,30 @@ private[spark] class StandaloneExecutorBackend( private[spark] object StandaloneExecutorBackend { def run(driverUrl: String, executorId: String, hostname: String, cores: Int) { + SparkHadoopUtil.runAsUser(run0, Tuple4[Any, Any, Any, Any] (driverUrl, executorId, hostname, cores)) + } + + // This will be run 'as' the user + def run0(args: Product) { + assert(4 == args.productArity) + runImpl(args.productElement(0).asInstanceOf[String], + args.productElement(1).asInstanceOf[String], + args.productElement(2).asInstanceOf[String], + args.productElement(3).asInstanceOf[Int]) + } + + private def runImpl(driverUrl: String, executorId: String, hostname: String, cores: Int) { + // Debug code + Utils.checkHost(hostname) + // Create a new ActorSystem to run the backend, because we can't create a SparkEnv / Executor // before getting started with all our system properties, etc val (actorSystem, boundPort) = AkkaUtils.createActorSystem("sparkExecutor", hostname, 0) + // set it + val sparkHostPort = hostname + ":" + boundPort + System.setProperty("spark.hostPort", sparkHostPort) val actor = actorSystem.actorOf( - Props(new StandaloneExecutorBackend(driverUrl, executorId, hostname, cores)), + Props(new StandaloneExecutorBackend(driverUrl, executorId, sparkHostPort, cores)), name = "Executor") actorSystem.awaitTermination() } diff --git a/core/src/main/scala/spark/executor/TaskMetrics.scala b/core/src/main/scala/spark/executor/TaskMetrics.scala index 93bbb6b458..1dc13754f9 100644 --- a/core/src/main/scala/spark/executor/TaskMetrics.scala +++ b/core/src/main/scala/spark/executor/TaskMetrics.scala @@ -2,6 +2,11 @@ package spark.executor class TaskMetrics extends Serializable { /** + * Host's name the task runs on + */ + var hostname: String = _ + + /** * Time taken on the executor to 
deserialize this task */ var executorDeserializeTime: Int = _ @@ -34,9 +39,14 @@ object TaskMetrics { class ShuffleReadMetrics extends Serializable { /** + * Time when shuffle finishes + */ + var shuffleFinishTime: Long = _ + + /** * Total number of blocks fetched in a shuffle (remote or local) */ - var totalBlocksFetched : Int = _ + var totalBlocksFetched: Int = _ /** * Number of remote blocks fetched in a shuffle @@ -49,11 +59,6 @@ class ShuffleReadMetrics extends Serializable { var localBlocksFetched: Int = _ /** - * Total time to read shuffle data - */ - var shuffleReadMillis: Long = _ - - /** * Total time that is spent blocked waiting for shuffle to fetch data */ var fetchWaitTime: Long = _ diff --git a/core/src/main/scala/spark/network/BufferMessage.scala b/core/src/main/scala/spark/network/BufferMessage.scala new file mode 100644 index 0000000000..7b0e489a6c --- /dev/null +++ b/core/src/main/scala/spark/network/BufferMessage.scala @@ -0,0 +1,94 @@ +package spark.network + +import java.nio.ByteBuffer + +import scala.collection.mutable.ArrayBuffer + +import spark.storage.BlockManager + + +private[spark] +class BufferMessage(id_ : Int, val buffers: ArrayBuffer[ByteBuffer], var ackId: Int) + extends Message(Message.BUFFER_MESSAGE, id_) { + + val initialSize = currentSize() + var gotChunkForSendingOnce = false + + def size = initialSize + + def currentSize() = { + if (buffers == null || buffers.isEmpty) { + 0 + } else { + buffers.map(_.remaining).reduceLeft(_ + _) + } + } + + def getChunkForSending(maxChunkSize: Int): Option[MessageChunk] = { + if (maxChunkSize <= 0) { + throw new Exception("Max chunk size is " + maxChunkSize) + } + + if (size == 0 && gotChunkForSendingOnce == false) { + val newChunk = new MessageChunk( + new MessageChunkHeader(typ, id, 0, 0, ackId, senderAddress), null) + gotChunkForSendingOnce = true + return Some(newChunk) + } + + while(!buffers.isEmpty) { + val buffer = buffers(0) + if (buffer.remaining == 0) { + BlockManager.dispose(buffer) + 
buffers -= buffer + } else { + val newBuffer = if (buffer.remaining <= maxChunkSize) { + buffer.duplicate() + } else { + buffer.slice().limit(maxChunkSize).asInstanceOf[ByteBuffer] + } + buffer.position(buffer.position + newBuffer.remaining) + val newChunk = new MessageChunk(new MessageChunkHeader( + typ, id, size, newBuffer.remaining, ackId, senderAddress), newBuffer) + gotChunkForSendingOnce = true + return Some(newChunk) + } + } + None + } + + def getChunkForReceiving(chunkSize: Int): Option[MessageChunk] = { + // STRONG ASSUMPTION: BufferMessage created when receiving data has ONLY ONE data buffer + if (buffers.size > 1) { + throw new Exception("Attempting to get chunk from message with multiple data buffers") + } + val buffer = buffers(0) + if (buffer.remaining > 0) { + if (buffer.remaining < chunkSize) { + throw new Exception("Not enough space in data buffer for receiving chunk") + } + val newBuffer = buffer.slice().limit(chunkSize).asInstanceOf[ByteBuffer] + buffer.position(buffer.position + newBuffer.remaining) + val newChunk = new MessageChunk(new MessageChunkHeader( + typ, id, size, newBuffer.remaining, ackId, senderAddress), newBuffer) + return Some(newChunk) + } + None + } + + def flip() { + buffers.foreach(_.flip) + } + + def hasAckId() = (ackId != 0) + + def isCompletelyReceived() = !buffers(0).hasRemaining + + override def toString = { + if (hasAckId) { + "BufferAckMessage(aid = " + ackId + ", id = " + id + ", size = " + size + ")" + } else { + "BufferMessage(id = " + id + ", size = " + size + ")" + } + } +}
\ No newline at end of file diff --git a/core/src/main/scala/spark/network/Connection.scala b/core/src/main/scala/spark/network/Connection.scala index d1451bc212..6e28f677a3 100644 --- a/core/src/main/scala/spark/network/Connection.scala +++ b/core/src/main/scala/spark/network/Connection.scala @@ -13,12 +13,13 @@ import java.net._ private[spark] abstract class Connection(val channel: SocketChannel, val selector: Selector, - val remoteConnectionManagerId: ConnectionManagerId) extends Logging { + val socketRemoteConnectionManagerId: ConnectionManagerId) + extends Logging { + def this(channel_ : SocketChannel, selector_ : Selector) = { this(channel_, selector_, - ConnectionManagerId.fromSocketAddress( - channel_.socket.getRemoteSocketAddress().asInstanceOf[InetSocketAddress] - )) + ConnectionManagerId.fromSocketAddress( + channel_.socket.getRemoteSocketAddress().asInstanceOf[InetSocketAddress])) } channel.configureBlocking(false) @@ -33,16 +34,47 @@ abstract class Connection(val channel: SocketChannel, val selector: Selector, val remoteAddress = getRemoteAddress() + // Read channels typically do not register for write and write does not for read + // Now, we do have write registering for read too (temporarily), but this is to detect + // channel close NOT to actually read/consume data on it ! + // How does this work if/when we move to SSL ? + + // What is the interest to register with selector for when we want this connection to be selected + def registerInterest() + + // What is the interest to register with selector for when we want this connection to + // be de-selected + // Traditionally, 0 - but in our case, for example, for close-detection on SendingConnection hack, + // it will be SelectionKey.OP_READ (until we fix it properly) + def unregisterInterest() + + // On receiving a read event, should we change the interest for this channel or not ? + // Will be true for ReceivingConnection, false for SendingConnection. 
+ def changeInterestForRead(): Boolean + + // On receiving a write event, should we change the interest for this channel or not ? + // Will be false for ReceivingConnection, true for SendingConnection. + // Actually, for now, should not get triggered for ReceivingConnection + def changeInterestForWrite(): Boolean + + def getRemoteConnectionManagerId(): ConnectionManagerId = { + socketRemoteConnectionManagerId + } + def key() = channel.keyFor(selector) def getRemoteAddress() = channel.socket.getRemoteSocketAddress().asInstanceOf[InetSocketAddress] - def read() { - throw new UnsupportedOperationException("Cannot read on connection of type " + this.getClass.toString) + // Returns whether we have to register for further reads or not. + def read(): Boolean = { + throw new UnsupportedOperationException( + "Cannot read on connection of type " + this.getClass.toString) } - - def write() { - throw new UnsupportedOperationException("Cannot write on connection of type " + this.getClass.toString) + + // Returns whether we have to register for further writes or not. 
+ def write(): Boolean = { + throw new UnsupportedOperationException( + "Cannot write on connection of type " + this.getClass.toString) } def close() { @@ -54,26 +86,32 @@ abstract class Connection(val channel: SocketChannel, val selector: Selector, callOnCloseCallback() } - def onClose(callback: Connection => Unit) {onCloseCallback = callback} + def onClose(callback: Connection => Unit) { + onCloseCallback = callback + } - def onException(callback: (Connection, Exception) => Unit) {onExceptionCallback = callback} + def onException(callback: (Connection, Exception) => Unit) { + onExceptionCallback = callback + } - def onKeyInterestChange(callback: (Connection, Int) => Unit) {onKeyInterestChangeCallback = callback} + def onKeyInterestChange(callback: (Connection, Int) => Unit) { + onKeyInterestChangeCallback = callback + } def callOnExceptionCallback(e: Exception) { if (onExceptionCallback != null) { onExceptionCallback(this, e) } else { - logError("Error in connection to " + remoteConnectionManagerId + + logError("Error in connection to " + getRemoteConnectionManagerId() + " and OnExceptionCallback not registered", e) } } - + def callOnCloseCallback() { if (onCloseCallback != null) { onCloseCallback(this) } else { - logWarning("Connection to " + remoteConnectionManagerId + + logWarning("Connection to " + getRemoteConnectionManagerId() + " closed and OnExceptionCallback not registered") } @@ -81,7 +119,7 @@ abstract class Connection(val channel: SocketChannel, val selector: Selector, def changeConnectionKeyInterest(ops: Int) { if (onKeyInterestChangeCallback != null) { - onKeyInterestChangeCallback(this, ops) + onKeyInterestChangeCallback(this, ops) } else { throw new Exception("OnKeyInterestChangeCallback not registered") } @@ -105,24 +143,25 @@ abstract class Connection(val channel: SocketChannel, val selector: Selector, print(" (" + position + ", " + length + ")") buffer.position(curPosition) } - } -private[spark] class SendingConnection(val address: 
InetSocketAddress, selector_ : Selector, - remoteId_ : ConnectionManagerId) -extends Connection(SocketChannel.open, selector_, remoteId_) { +private[spark] +class SendingConnection(val address: InetSocketAddress, selector_ : Selector, + remoteId_ : ConnectionManagerId) + extends Connection(SocketChannel.open, selector_, remoteId_) { class Outbox(fair: Int = 0) { val messages = new Queue[Message]() - val defaultChunkSize = 65536 //32768 //16384 + val defaultChunkSize = 65536 //32768 //16384 var nextMessageToBeUsed = 0 def addMessage(message: Message) { - messages.synchronized{ + messages.synchronized{ /*messages += message*/ messages.enqueue(message) - logDebug("Added [" + message + "] to outbox for sending to [" + remoteConnectionManagerId + "]") + logDebug("Added [" + message + "] to outbox for sending to " + + "[" + getRemoteConnectionManagerId() + "]") } } @@ -147,18 +186,18 @@ extends Connection(SocketChannel.open, selector_, remoteId_) { message.started = true message.startTime = System.currentTimeMillis } - return chunk + return chunk } else { - /*logInfo("Finished sending [" + message + "] to [" + remoteConnectionManagerId + "]")*/ + /*logInfo("Finished sending [" + message + "] to [" + getRemoteConnectionManagerId() + "]")*/ message.finishTime = System.currentTimeMillis - logDebug("Finished sending [" + message + "] to [" + remoteConnectionManagerId + + logDebug("Finished sending [" + message + "] to [" + getRemoteConnectionManagerId() + "] in " + message.timeTaken ) } } } None } - + private def getChunkRR(): Option[MessageChunk] = { messages.synchronized { while (!messages.isEmpty) { @@ -170,15 +209,17 @@ extends Connection(SocketChannel.open, selector_, remoteId_) { messages.enqueue(message) nextMessageToBeUsed = nextMessageToBeUsed + 1 if (!message.started) { - logDebug("Starting to send [" + message + "] to [" + remoteConnectionManagerId + "]") + logDebug( + "Starting to send [" + message + "] to [" + getRemoteConnectionManagerId() + "]") 
message.started = true message.startTime = System.currentTimeMillis } - logTrace("Sending chunk from [" + message+ "] to [" + remoteConnectionManagerId + "]") - return chunk + logTrace( + "Sending chunk from [" + message+ "] to [" + getRemoteConnectionManagerId() + "]") + return chunk } else { message.finishTime = System.currentTimeMillis - logDebug("Finished sending [" + message + "] to [" + remoteConnectionManagerId + + logDebug("Finished sending [" + message + "] to [" + getRemoteConnectionManagerId() + "] in " + message.timeTaken ) } } @@ -186,27 +227,40 @@ extends Connection(SocketChannel.open, selector_, remoteId_) { None } } - - val outbox = new Outbox(1) + + private val outbox = new Outbox(1) val currentBuffers = new ArrayBuffer[ByteBuffer]() /*channel.socket.setSendBufferSize(256 * 1024)*/ - override def getRemoteAddress() = address + override def getRemoteAddress() = address + + val DEFAULT_INTEREST = SelectionKey.OP_READ + + override def registerInterest() { + // Registering read too - does not really help in most cases, but for some + // it does - so let us keep it for now. 
+ changeConnectionKeyInterest(SelectionKey.OP_WRITE | DEFAULT_INTEREST) + } + + override def unregisterInterest() { + changeConnectionKeyInterest(DEFAULT_INTEREST) + } def send(message: Message) { outbox.synchronized { outbox.addMessage(message) if (channel.isConnected) { - changeConnectionKeyInterest(SelectionKey.OP_WRITE | SelectionKey.OP_READ) + registerInterest() } } } + // MUST be called within the selector loop def connect() { try{ - channel.connect(address) channel.register(selector, SelectionKey.OP_CONNECT) + channel.connect(address) logInfo("Initiating connection to [" + address + "]") } catch { case e: Exception => { @@ -216,36 +270,52 @@ extends Connection(SocketChannel.open, selector_, remoteId_) { } } - def finishConnect() { + def finishConnect(force: Boolean): Boolean = { try { - channel.finishConnect - changeConnectionKeyInterest(SelectionKey.OP_WRITE | SelectionKey.OP_READ) + // Typically, this should finish immediately since it was triggered by a connect + // selection - though need not necessarily always complete successfully. 
+ val connected = channel.finishConnect + if (!force && !connected) { + logInfo( + "finish connect failed [" + address + "], " + outbox.messages.size + " messages pending") + return false + } + + // Fallback to previous behavior - assume finishConnect completed + // This will happen only when finishConnect failed for some repeated number of times + // (10 or so) + // Is highly unlikely unless there was an unclean close of socket, etc + registerInterest() logInfo("Connected to [" + address + "], " + outbox.messages.size + " messages pending") + return true } catch { case e: Exception => { logWarning("Error finishing connection to " + address, e) callOnExceptionCallback(e) + // ignore + return true } } } - override def write() { - try{ - while(true) { + override def write(): Boolean = { + try { + while (true) { if (currentBuffers.size == 0) { outbox.synchronized { outbox.getChunk() match { case Some(chunk) => { - currentBuffers ++= chunk.buffers + currentBuffers ++= chunk.buffers } case None => { - changeConnectionKeyInterest(SelectionKey.OP_READ) - return + // changeConnectionKeyInterest(0) + /*key.interestOps(0)*/ + return false } } } } - + if (currentBuffers.size > 0) { val buffer = currentBuffers(0) val remainingBytes = buffer.remaining @@ -254,69 +324,109 @@ extends Connection(SocketChannel.open, selector_, remoteId_) { currentBuffers -= buffer } if (writtenBytes < remainingBytes) { - return + // re-register for write. + return true } } } } catch { - case e: Exception => { - logWarning("Error writing in connection to " + remoteConnectionManagerId, e) + case e: Exception => { + logWarning("Error writing in connection to " + getRemoteConnectionManagerId(), e) callOnExceptionCallback(e) close() + return false } } + // should not happen - to keep scala compiler happy + return true } - override def read() { + // This is a hack to determine if remote socket was closed or not. 
+ // SendingConnection DOES NOT expect to receive any data - if it does, it is an error + // For a bunch of cases, read will return -1 in case remote socket is closed : hence we + // register for reads to determine that. + override def read(): Boolean = { // We don't expect the other side to send anything; so, we just read to detect an error or EOF. try { val length = channel.read(ByteBuffer.allocate(1)) if (length == -1) { // EOF close() } else if (length > 0) { - logWarning("Unexpected data read from SendingConnection to " + remoteConnectionManagerId) + logWarning( + "Unexpected data read from SendingConnection to " + getRemoteConnectionManagerId()) } } catch { case e: Exception => - logError("Exception while reading SendingConnection to " + remoteConnectionManagerId, e) + logError("Exception while reading SendingConnection to " + getRemoteConnectionManagerId(), e) callOnExceptionCallback(e) close() } + + false } + + override def changeInterestForRead(): Boolean = false + + override def changeInterestForWrite(): Boolean = true } -private[spark] class ReceivingConnection(channel_ : SocketChannel, selector_ : Selector) -extends Connection(channel_, selector_) { - +// Must be created within selector loop - else deadlock +private[spark] class ReceivingConnection(channel_ : SocketChannel, selector_ : Selector) + extends Connection(channel_, selector_) { + class Inbox() { val messages = new HashMap[Int, BufferMessage]() - + def getChunk(header: MessageChunkHeader): Option[MessageChunk] = { - + def createNewMessage: BufferMessage = { val newMessage = Message.create(header).asInstanceOf[BufferMessage] newMessage.started = true newMessage.startTime = System.currentTimeMillis - logDebug("Starting to receive [" + newMessage + "] from [" + remoteConnectionManagerId + "]") + logDebug( + "Starting to receive [" + newMessage + "] from [" + getRemoteConnectionManagerId() + "]") messages += ((newMessage.id, newMessage)) newMessage } - + val message = 
messages.getOrElseUpdate(header.id, createNewMessage) - logTrace("Receiving chunk of [" + message + "] from [" + remoteConnectionManagerId + "]") + logTrace( + "Receiving chunk of [" + message + "] from [" + getRemoteConnectionManagerId() + "]") message.getChunkForReceiving(header.chunkSize) } - + def getMessageForChunk(chunk: MessageChunk): Option[BufferMessage] = { - messages.get(chunk.header.id) + messages.get(chunk.header.id) } def removeMessage(message: Message) { messages -= message.id } } - + + @volatile private var inferredRemoteManagerId: ConnectionManagerId = null + + override def getRemoteConnectionManagerId(): ConnectionManagerId = { + val currId = inferredRemoteManagerId + if (currId != null) currId else super.getRemoteConnectionManagerId() + } + + // The reciever's remote address is the local socket on remote side : which is NOT + // the connection manager id of the receiver. + // We infer that from the messages we receive on the receiver socket. + private def processConnectionManagerId(header: MessageChunkHeader) { + val currId = inferredRemoteManagerId + if (header.address == null || currId != null) return + + val managerId = ConnectionManagerId.fromSocketAddress(header.address) + + if (managerId != null) { + inferredRemoteManagerId = managerId + } + } + + val inbox = new Inbox() val headerBuffer: ByteBuffer = ByteBuffer.allocate(MessageChunkHeader.HEADER_SIZE) var onReceiveCallback: (Connection , Message) => Unit = null @@ -324,24 +434,29 @@ extends Connection(channel_, selector_) { channel.register(selector, SelectionKey.OP_READ) - override def read() { + override def read(): Boolean = { try { while (true) { if (currentChunk == null) { val headerBytesRead = channel.read(headerBuffer) if (headerBytesRead == -1) { close() - return + return false } if (headerBuffer.remaining > 0) { - return + // re-register for read event ... 
+ return true } headerBuffer.flip if (headerBuffer.remaining != MessageChunkHeader.HEADER_SIZE) { - throw new Exception("Unexpected number of bytes (" + headerBuffer.remaining + ") in the header") + throw new Exception( + "Unexpected number of bytes (" + headerBuffer.remaining + ") in the header") } val header = MessageChunkHeader.create(headerBuffer) headerBuffer.clear() + + processConnectionManagerId(header) + header.typ match { case Message.BUFFER_MESSAGE => { if (header.totalSize == 0) { @@ -349,7 +464,8 @@ extends Connection(channel_, selector_) { onReceiveCallback(this, Message.create(header)) } currentChunk = null - return + // re-register for read event ... + return true } else { currentChunk = inbox.getChunk(header).orNull } @@ -357,26 +473,28 @@ extends Connection(channel_, selector_) { case _ => throw new Exception("Message of unknown type received") } } - + if (currentChunk == null) throw new Exception("No message chunk to receive data") - + val bytesRead = channel.read(currentChunk.buffer) if (bytesRead == 0) { - return + // re-register for read event ... 
+ return true } else if (bytesRead == -1) { close() - return + return false } /*logDebug("Read " + bytesRead + " bytes for the buffer")*/ - + if (currentChunk.buffer.remaining == 0) { /*println("Filled buffer at " + System.currentTimeMillis)*/ val bufferMessage = inbox.getMessageForChunk(currentChunk).get if (bufferMessage.isCompletelyReceived) { bufferMessage.flip bufferMessage.finishTime = System.currentTimeMillis - logDebug("Finished receiving [" + bufferMessage + "] from [" + remoteConnectionManagerId + "] in " + bufferMessage.timeTaken) + logDebug("Finished receiving [" + bufferMessage + "] from " + + "[" + getRemoteConnectionManagerId() + "] in " + bufferMessage.timeTaken) if (onReceiveCallback != null) { onReceiveCallback(this, bufferMessage) } @@ -386,13 +504,32 @@ extends Connection(channel_, selector_) { } } } catch { - case e: Exception => { - logWarning("Error reading from connection to " + remoteConnectionManagerId, e) + case e: Exception => { + logWarning("Error reading from connection to " + getRemoteConnectionManagerId(), e) callOnExceptionCallback(e) close() + return false } } + // should not happen - to keep scala compiler happy + return true } - + def onReceive(callback: (Connection, Message) => Unit) {onReceiveCallback = callback} + + override def changeInterestForRead(): Boolean = true + + override def changeInterestForWrite(): Boolean = { + throw new IllegalStateException("Unexpected invocation right now") + } + + override def registerInterest() { + // Registering read too - does not really help in most cases, but for some + // it does - so let us keep it for now. 
+ changeConnectionKeyInterest(SelectionKey.OP_READ) + } + + override def unregisterInterest() { + changeConnectionKeyInterest(0) + } } diff --git a/core/src/main/scala/spark/network/ConnectionManager.scala b/core/src/main/scala/spark/network/ConnectionManager.scala index 8f8892b8c7..cc5c62a542 100644 --- a/core/src/main/scala/spark/network/ConnectionManager.scala +++ b/core/src/main/scala/spark/network/ConnectionManager.scala @@ -6,28 +6,19 @@ import java.nio._ import java.nio.channels._ import java.nio.channels.spi._ import java.net._ -import java.util.concurrent.Executors +import java.util.concurrent.{LinkedBlockingDeque, TimeUnit, ThreadPoolExecutor} +import scala.collection.mutable.HashSet import scala.collection.mutable.HashMap import scala.collection.mutable.SynchronizedMap import scala.collection.mutable.SynchronizedQueue -import scala.collection.mutable.Queue import scala.collection.mutable.ArrayBuffer import scala.concurrent.{Await, Promise, ExecutionContext, Future} import scala.concurrent.duration.Duration import scala.concurrent.duration._ -private[spark] case class ConnectionManagerId(host: String, port: Int) { - def toSocketAddress() = new InetSocketAddress(host, port) -} -private[spark] object ConnectionManagerId { - def fromSocketAddress(socketAddress: InetSocketAddress): ConnectionManagerId = { - new ConnectionManagerId(socketAddress.getHostName(), socketAddress.getPort()) - } -} - private[spark] class ConnectionManager(port: Int) extends Logging { class MessageStatus( @@ -41,73 +32,263 @@ private[spark] class ConnectionManager(port: Int) extends Logging { def markDone() { completionHandler(this) } } - - val selector = SelectorProvider.provider.openSelector() - val handleMessageExecutor = Executors.newFixedThreadPool(System.getProperty("spark.core.connection.handler.threads","20").toInt) - val serverChannel = ServerSocketChannel.open() - val connectionsByKey = new HashMap[SelectionKey, Connection] with SynchronizedMap[SelectionKey, Connection] - 
val connectionsById = new HashMap[ConnectionManagerId, SendingConnection] with SynchronizedMap[ConnectionManagerId, SendingConnection] - val messageStatuses = new HashMap[Int, MessageStatus] - val connectionRequests = new HashMap[ConnectionManagerId, SendingConnection] with SynchronizedMap[ConnectionManagerId, SendingConnection] - val keyInterestChangeRequests = new SynchronizedQueue[(SelectionKey, Int)] - val sendMessageRequests = new Queue[(Message, SendingConnection)] + + private val selector = SelectorProvider.provider.openSelector() + + private val handleMessageExecutor = new ThreadPoolExecutor( + System.getProperty("spark.core.connection.handler.threads.min","20").toInt, + System.getProperty("spark.core.connection.handler.threads.max","60").toInt, + System.getProperty("spark.core.connection.handler.threads.keepalive","60").toInt, TimeUnit.SECONDS, + new LinkedBlockingDeque[Runnable]()) + + private val handleReadWriteExecutor = new ThreadPoolExecutor( + System.getProperty("spark.core.connection.io.threads.min","4").toInt, + System.getProperty("spark.core.connection.io.threads.max","32").toInt, + System.getProperty("spark.core.connection.io.threads.keepalive","60").toInt, TimeUnit.SECONDS, + new LinkedBlockingDeque[Runnable]()) + + // Use a different, yet smaller, thread pool - infrequently used with very short lived tasks : which should be executed asap + private val handleConnectExecutor = new ThreadPoolExecutor( + System.getProperty("spark.core.connection.connect.threads.min","1").toInt, + System.getProperty("spark.core.connection.connect.threads.max","8").toInt, + System.getProperty("spark.core.connection.connect.threads.keepalive","60").toInt, TimeUnit.SECONDS, + new LinkedBlockingDeque[Runnable]()) + + private val serverChannel = ServerSocketChannel.open() + private val connectionsByKey = new HashMap[SelectionKey, Connection] with SynchronizedMap[SelectionKey, Connection] + private val connectionsById = new HashMap[ConnectionManagerId, SendingConnection] 
with SynchronizedMap[ConnectionManagerId, SendingConnection] + private val messageStatuses = new HashMap[Int, MessageStatus] + private val keyInterestChangeRequests = new SynchronizedQueue[(SelectionKey, Int)] + private val registerRequests = new SynchronizedQueue[SendingConnection] implicit val futureExecContext = ExecutionContext.fromExecutor(Utils.newDaemonCachedThreadPool()) - var onReceiveCallback: (BufferMessage, ConnectionManagerId) => Option[Message]= null + private var onReceiveCallback: (BufferMessage, ConnectionManagerId) => Option[Message]= null serverChannel.configureBlocking(false) serverChannel.socket.setReuseAddress(true) - serverChannel.socket.setReceiveBufferSize(256 * 1024) + serverChannel.socket.setReceiveBufferSize(256 * 1024) serverChannel.socket.bind(new InetSocketAddress(port)) serverChannel.register(selector, SelectionKey.OP_ACCEPT) val id = new ConnectionManagerId(Utils.localHostName, serverChannel.socket.getLocalPort) logInfo("Bound socket to port " + serverChannel.socket.getLocalPort() + " with id = " + id) - - val selectorThread = new Thread("connection-manager-thread") { + + private val selectorThread = new Thread("connection-manager-thread") { override def run() = ConnectionManager.this.run() } selectorThread.setDaemon(true) selectorThread.start() - private def run() { - try { - while(!selectorThread.isInterrupted) { - for ((connectionManagerId, sendingConnection) <- connectionRequests) { - sendingConnection.connect() - addConnection(sendingConnection) - connectionRequests -= connectionManagerId + private val writeRunnableStarted: HashSet[SelectionKey] = new HashSet[SelectionKey]() + + private def triggerWrite(key: SelectionKey) { + val conn = connectionsByKey.getOrElse(key, null) + if (conn == null) return + + writeRunnableStarted.synchronized { + // So that we do not trigger more write events while processing this one. + // The write method will re-register when done. 
+ if (conn.changeInterestForWrite()) conn.unregisterInterest() + if (writeRunnableStarted.contains(key)) { + // key.interestOps(key.interestOps() & ~ SelectionKey.OP_WRITE) + return + } + + writeRunnableStarted += key + } + handleReadWriteExecutor.execute(new Runnable { + override def run() { + var register: Boolean = false + try { + register = conn.write() + } finally { + writeRunnableStarted.synchronized { + writeRunnableStarted -= key + if (register && conn.changeInterestForWrite()) { + conn.registerInterest() + } + } } - sendMessageRequests.synchronized { - while (!sendMessageRequests.isEmpty) { - val (message, connection) = sendMessageRequests.dequeue - connection.send(message) + } + } ) + } + + private val readRunnableStarted: HashSet[SelectionKey] = new HashSet[SelectionKey]() + + private def triggerRead(key: SelectionKey) { + val conn = connectionsByKey.getOrElse(key, null) + if (conn == null) return + + readRunnableStarted.synchronized { + // So that we do not trigger more read events while processing this one. + // The read method will re-register when done. 
+ if (conn.changeInterestForRead())conn.unregisterInterest() + if (readRunnableStarted.contains(key)) { + return + } + + readRunnableStarted += key + } + handleReadWriteExecutor.execute(new Runnable { + override def run() { + var register: Boolean = false + try { + register = conn.read() + } finally { + readRunnableStarted.synchronized { + readRunnableStarted -= key + if (register && conn.changeInterestForRead()) { + conn.registerInterest() + } } } + } + } ) + } + + private def triggerConnect(key: SelectionKey) { + val conn = connectionsByKey.getOrElse(key, null).asInstanceOf[SendingConnection] + if (conn == null) return + + // prevent other events from being triggered + // Since we are still trying to connect, we do not need to do the additional steps in triggerWrite + conn.changeConnectionKeyInterest(0) + + handleConnectExecutor.execute(new Runnable { + override def run() { + + var tries: Int = 10 + while (tries >= 0) { + if (conn.finishConnect(false)) return + // Sleep ? + Thread.sleep(1) + tries -= 1 + } + + // fallback to previous behavior : we should not really come here since this method was + // triggered since channel became connectable : but at times, the first finishConnect need not + // succeed : hence the loop to retry a few 'times'. + conn.finishConnect(true) + } + } ) + } + + // MUST be called within selector loop - else deadlock. 
+ private def triggerForceCloseByException(key: SelectionKey, e: Exception) { + try { + key.interestOps(0) + } catch { + // ignore exceptions + case e: Exception => logDebug("Ignoring exception", e) + } + + val conn = connectionsByKey.getOrElse(key, null) + if (conn == null) return + + // Pushing to connect threadpool + handleConnectExecutor.execute(new Runnable { + override def run() { + try { + conn.callOnExceptionCallback(e) + } catch { + // ignore exceptions + case e: Exception => logDebug("Ignoring exception", e) + } + try { + conn.close() + } catch { + // ignore exceptions + case e: Exception => logDebug("Ignoring exception", e) + } + } + }) + } + - while (!keyInterestChangeRequests.isEmpty) { + def run() { + try { + while(!selectorThread.isInterrupted) { + while (! registerRequests.isEmpty) { + val conn: SendingConnection = registerRequests.dequeue + addListeners(conn) + conn.connect() + addConnection(conn) + } + + while(!keyInterestChangeRequests.isEmpty) { val (key, ops) = keyInterestChangeRequests.dequeue - val connection = connectionsByKey(key) - val lastOps = key.interestOps() - key.interestOps(ops) - - def intToOpStr(op: Int): String = { - val opStrs = ArrayBuffer[String]() - if ((op & SelectionKey.OP_READ) != 0) opStrs += "READ" - if ((op & SelectionKey.OP_WRITE) != 0) opStrs += "WRITE" - if ((op & SelectionKey.OP_CONNECT) != 0) opStrs += "CONNECT" - if ((op & SelectionKey.OP_ACCEPT) != 0) opStrs += "ACCEPT" - if (opStrs.size > 0) opStrs.reduceLeft(_ + " | " + _) else " " + + try { + if (key.isValid) { + val connection = connectionsByKey.getOrElse(key, null) + if (connection != null) { + val lastOps = key.interestOps() + key.interestOps(ops) + + // hot loop - prevent materialization of string if trace not enabled. 
+ if (isTraceEnabled()) { + def intToOpStr(op: Int): String = { + val opStrs = ArrayBuffer[String]() + if ((op & SelectionKey.OP_READ) != 0) opStrs += "READ" + if ((op & SelectionKey.OP_WRITE) != 0) opStrs += "WRITE" + if ((op & SelectionKey.OP_CONNECT) != 0) opStrs += "CONNECT" + if ((op & SelectionKey.OP_ACCEPT) != 0) opStrs += "ACCEPT" + if (opStrs.size > 0) opStrs.reduceLeft(_ + " | " + _) else " " + } + + logTrace("Changed key for connection to [" + connection.getRemoteConnectionManagerId() + + "] changed from [" + intToOpStr(lastOps) + "] to [" + intToOpStr(ops) + "]") + } + } + } else { + logInfo("Key not valid ? " + key) + throw new CancelledKeyException() + } + } catch { + case e: CancelledKeyException => { + logInfo("key already cancelled ? " + key, e) + triggerForceCloseByException(key, e) + } + case e: Exception => { + logError("Exception processing key " + key, e) + triggerForceCloseByException(key, e) + } } - - logTrace("Changed key for connection to [" + connection.remoteConnectionManagerId + - "] changed from [" + intToOpStr(lastOps) + "] to [" + intToOpStr(ops) + "]") - } - val selectedKeysCount = selector.select() + val selectedKeysCount = + try { + selector.select() + } catch { + // Explicitly only dealing with CancelledKeyException here since other exceptions should be dealt with differently. + case e: CancelledKeyException => { + // Some keys within the selectors list are invalid/closed. clear them. + val allKeys = selector.keys().iterator() + + while (allKeys.hasNext()) { + val key = allKeys.next() + try { + if (! key.isValid) { + logInfo("Key not valid ? " + key) + throw new CancelledKeyException() + } + } catch { + case e: CancelledKeyException => { + logInfo("key already cancelled ? 
" + key, e) + triggerForceCloseByException(key, e) + } + case e: Exception => { + logError("Exception processing key " + key, e) + triggerForceCloseByException(key, e) + } + } + } + } + 0 + } + if (selectedKeysCount == 0) { logDebug("Selector selected " + selectedKeysCount + " of " + selector.keys.size + " keys") } @@ -115,20 +296,40 @@ private[spark] class ConnectionManager(port: Int) extends Logging { logInfo("Selector thread was interrupted!") return } - - val selectedKeys = selector.selectedKeys().iterator() - while (selectedKeys.hasNext()) { - val key = selectedKeys.next - selectedKeys.remove() - if (key.isValid) { - if (key.isAcceptable) { - acceptConnection(key) - } else if (key.isConnectable) { - connectionsByKey(key).asInstanceOf[SendingConnection].finishConnect() - } else if (key.isReadable) { - connectionsByKey(key).read() - } else if (key.isWritable) { - connectionsByKey(key).write() + + if (0 != selectedKeysCount) { + val selectedKeys = selector.selectedKeys().iterator() + while (selectedKeys.hasNext()) { + val key = selectedKeys.next + selectedKeys.remove() + try { + if (key.isValid) { + if (key.isAcceptable) { + acceptConnection(key) + } else + if (key.isConnectable) { + triggerConnect(key) + } else + if (key.isReadable) { + triggerRead(key) + } else + if (key.isWritable) { + triggerWrite(key) + } + } else { + logInfo("Key not valid ? " + key) + throw new CancelledKeyException() + } + } catch { + // weird, but we saw this happening - even though key.isValid was true, key.isAcceptable would throw CancelledKeyException. + case e: CancelledKeyException => { + logInfo("key already cancelled ? 
" + key, e) + triggerForceCloseByException(key, e) + } + case e: Exception => { + logError("Exception processing key " + key, e) + triggerForceCloseByException(key, e) + } } } } @@ -137,97 +338,119 @@ private[spark] class ConnectionManager(port: Int) extends Logging { case e: Exception => logError("Error in select loop", e) } } - - private def acceptConnection(key: SelectionKey) { + + def acceptConnection(key: SelectionKey) { val serverChannel = key.channel.asInstanceOf[ServerSocketChannel] - val newChannel = serverChannel.accept() - val newConnection = new ReceivingConnection(newChannel, selector) - newConnection.onReceive(receiveMessage) - newConnection.onClose(removeConnection) - addConnection(newConnection) - logInfo("Accepted connection from [" + newConnection.remoteAddress.getAddress + "]") - } - private def addConnection(connection: Connection) { - connectionsByKey += ((connection.key, connection)) - if (connection.isInstanceOf[SendingConnection]) { - val sendingConnection = connection.asInstanceOf[SendingConnection] - connectionsById += ((sendingConnection.remoteConnectionManagerId, sendingConnection)) + var newChannel = serverChannel.accept() + + // accept them all in a tight loop. 
non blocking accept with no processing, should be fine + while (newChannel != null) { + try { + val newConnection = new ReceivingConnection(newChannel, selector) + newConnection.onReceive(receiveMessage) + addListeners(newConnection) + addConnection(newConnection) + logInfo("Accepted connection from [" + newConnection.remoteAddress.getAddress + "]") + } catch { + // might happen in case of issues with registering with selector + case e: Exception => logError("Error in accept loop", e) + } + + newChannel = serverChannel.accept() } + } + + private def addListeners(connection: Connection) { connection.onKeyInterestChange(changeConnectionKeyInterest) connection.onException(handleConnectionError) connection.onClose(removeConnection) } - private def removeConnection(connection: Connection) { + def addConnection(connection: Connection) { + connectionsByKey += ((connection.key, connection)) + } + + def removeConnection(connection: Connection) { connectionsByKey -= connection.key - if (connection.isInstanceOf[SendingConnection]) { - val sendingConnection = connection.asInstanceOf[SendingConnection] - val sendingConnectionManagerId = sendingConnection.remoteConnectionManagerId - logInfo("Removing SendingConnection to " + sendingConnectionManagerId) - - connectionsById -= sendingConnectionManagerId - - messageStatuses.synchronized { - messageStatuses - .values.filter(_.connectionManagerId == sendingConnectionManagerId).foreach(status => { - logInfo("Notifying " + status) - status.synchronized { - status.attempted = true - status.acked = false - status.markDone() - } + + try { + if (connection.isInstanceOf[SendingConnection]) { + val sendingConnection = connection.asInstanceOf[SendingConnection] + val sendingConnectionManagerId = sendingConnection.getRemoteConnectionManagerId() + logInfo("Removing SendingConnection to " + sendingConnectionManagerId) + + connectionsById -= sendingConnectionManagerId + + messageStatuses.synchronized { + messageStatuses + 
.values.filter(_.connectionManagerId == sendingConnectionManagerId).foreach(status => { + logInfo("Notifying " + status) + status.synchronized { + status.attempted = true + status.acked = false + status.markDone() + } + }) + + messageStatuses.retain((i, status) => { + status.connectionManagerId != sendingConnectionManagerId }) + } + } else if (connection.isInstanceOf[ReceivingConnection]) { + val receivingConnection = connection.asInstanceOf[ReceivingConnection] + val remoteConnectionManagerId = receivingConnection.getRemoteConnectionManagerId() + logInfo("Removing ReceivingConnection to " + remoteConnectionManagerId) + + val sendingConnectionOpt = connectionsById.get(remoteConnectionManagerId) + if (! sendingConnectionOpt.isDefined) { + logError("Corresponding SendingConnectionManagerId not found") + return + } - messageStatuses.retain((i, status) => { - status.connectionManagerId != sendingConnectionManagerId - }) - } - } else if (connection.isInstanceOf[ReceivingConnection]) { - val receivingConnection = connection.asInstanceOf[ReceivingConnection] - val remoteConnectionManagerId = receivingConnection.remoteConnectionManagerId - logInfo("Removing ReceivingConnection to " + remoteConnectionManagerId) - - val sendingConnectionManagerId = connectionsById.keys.find(_.host == remoteConnectionManagerId.host).orNull - if (sendingConnectionManagerId == null) { - logError("Corresponding SendingConnectionManagerId not found") - return - } - logInfo("Corresponding SendingConnectionManagerId is " + sendingConnectionManagerId) - - val sendingConnection = connectionsById(sendingConnectionManagerId) - sendingConnection.close() - connectionsById -= sendingConnectionManagerId - - messageStatuses.synchronized { - for (s <- messageStatuses.values if s.connectionManagerId == sendingConnectionManagerId) { - logInfo("Notifying " + s) - s.synchronized { - s.attempted = true - s.acked = false - s.markDone() + val sendingConnection = sendingConnectionOpt.get + connectionsById -= 
remoteConnectionManagerId + sendingConnection.close() + + val sendingConnectionManagerId = sendingConnection.getRemoteConnectionManagerId() + + assert (sendingConnectionManagerId == remoteConnectionManagerId) + + messageStatuses.synchronized { + for (s <- messageStatuses.values if s.connectionManagerId == sendingConnectionManagerId) { + logInfo("Notifying " + s) + s.synchronized { + s.attempted = true + s.acked = false + s.markDone() + } } - } - messageStatuses.retain((i, status) => { - status.connectionManagerId != sendingConnectionManagerId - }) + messageStatuses.retain((i, status) => { + status.connectionManagerId != sendingConnectionManagerId + }) + } } + } finally { + // So that the selection keys can be removed. + wakeupSelector() } } - private def handleConnectionError(connection: Connection, e: Exception) { - logInfo("Handling connection error on connection to " + connection.remoteConnectionManagerId) + def handleConnectionError(connection: Connection, e: Exception) { + logInfo("Handling connection error on connection to " + connection.getRemoteConnectionManagerId()) removeConnection(connection) } - private def changeConnectionKeyInterest(connection: Connection, ops: Int) { - keyInterestChangeRequests += ((connection.key, ops)) + def changeConnectionKeyInterest(connection: Connection, ops: Int) { + keyInterestChangeRequests += ((connection.key, ops)) + // so that registerations happen ! 
+ wakeupSelector() } - private def receiveMessage(connection: Connection, message: Message) { + def receiveMessage(connection: Connection, message: Message) { val connectionManagerId = ConnectionManagerId.fromSocketAddress(message.senderAddress) - logDebug("Received [" + message + "] from [" + connectionManagerId + "]") + logDebug("Received [" + message + "] from [" + connectionManagerId + "]") val runnable = new Runnable() { val creationTime = System.currentTimeMillis def run() { @@ -247,11 +470,11 @@ private[spark] class ConnectionManager(port: Int) extends Logging { if (bufferMessage.hasAckId) { val sentMessageStatus = messageStatuses.synchronized { messageStatuses.get(bufferMessage.ackId) match { - case Some(status) => { - messageStatuses -= bufferMessage.ackId + case Some(status) => { + messageStatuses -= bufferMessage.ackId status } - case None => { + case None => { throw new Exception("Could not find reference for received ack message " + message.id) null } @@ -271,7 +494,7 @@ private[spark] class ConnectionManager(port: Int) extends Logging { logDebug("Not calling back as callback is null") None } - + if (ackMessage.isDefined) { if (!ackMessage.get.isInstanceOf[BufferMessage]) { logDebug("Response to " + bufferMessage + " is not a buffer message, it is of type " + ackMessage.get.getClass()) @@ -281,7 +504,7 @@ private[spark] class ConnectionManager(port: Int) extends Logging { } } - sendMessage(connectionManagerId, ackMessage.getOrElse { + sendMessage(connectionManagerId, ackMessage.getOrElse { Message.createBufferMessage(bufferMessage.id) }) } @@ -293,18 +516,22 @@ private[spark] class ConnectionManager(port: Int) extends Logging { private def sendMessage(connectionManagerId: ConnectionManagerId, message: Message) { def startNewConnection(): SendingConnection = { val inetSocketAddress = new InetSocketAddress(connectionManagerId.host, connectionManagerId.port) - val newConnection = connectionRequests.getOrElseUpdate(connectionManagerId, - new 
SendingConnection(inetSocketAddress, selector, connectionManagerId)) - newConnection + val newConnection = new SendingConnection(inetSocketAddress, selector, connectionManagerId) + registerRequests.enqueue(newConnection) + + newConnection } - val lookupKey = ConnectionManagerId.fromSocketAddress(connectionManagerId.toSocketAddress) - val connection = connectionsById.getOrElse(lookupKey, startNewConnection()) + // I removed the lookupKey stuff as part of merge ... should I re-add it ? We did not find it useful in our test-env ... + // If we do re-add it, we should consistently use it everywhere I guess ? + val connection = connectionsById.getOrElseUpdate(connectionManagerId, startNewConnection()) message.senderAddress = id.toSocketAddress() logDebug("Sending [" + message + "] to [" + connectionManagerId + "]") - /*connection.send(message)*/ - sendMessageRequests.synchronized { - sendMessageRequests += ((message, connection)) - } + connection.send(message) + + wakeupSelector() + } + + private def wakeupSelector() { selector.wakeup() } @@ -337,6 +564,8 @@ private[spark] class ConnectionManager(port: Int) extends Logging { logWarning("All connections not cleaned up") } handleMessageExecutor.shutdown() + handleReadWriteExecutor.shutdown() + handleConnectExecutor.shutdown() logInfo("ConnectionManager stopped") } } @@ -346,17 +575,17 @@ private[spark] object ConnectionManager { def main(args: Array[String]) { val manager = new ConnectionManager(9999) - manager.onReceiveMessage((msg: Message, id: ConnectionManagerId) => { + manager.onReceiveMessage((msg: Message, id: ConnectionManagerId) => { println("Received [" + msg + "] from [" + id + "]") None }) - + /*testSequentialSending(manager)*/ /*System.gc()*/ /*testParallelSending(manager)*/ /*System.gc()*/ - + /*testParallelDecreasingSending(manager)*/ /*System.gc()*/ @@ -368,9 +597,9 @@ private[spark] object ConnectionManager { println("--------------------------") println("Sequential Sending") 
println("--------------------------") - val size = 10 * 1024 * 1024 + val size = 10 * 1024 * 1024 val count = 10 - + val buffer = ByteBuffer.allocate(size).put(Array.tabulate[Byte](size)(x => x.toByte)) buffer.flip @@ -386,7 +615,7 @@ private[spark] object ConnectionManager { println("--------------------------") println("Parallel Sending") println("--------------------------") - val size = 10 * 1024 * 1024 + val size = 10 * 1024 * 1024 val count = 10 val buffer = ByteBuffer.allocate(size).put(Array.tabulate[Byte](size)(x => x.toByte)) @@ -401,12 +630,12 @@ private[spark] object ConnectionManager { if (!g.isDefined) println("Failed") }) val finishTime = System.currentTimeMillis - + val mb = size * count / 1024.0 / 1024.0 val ms = finishTime - startTime val tput = mb * 1000.0 / ms println("--------------------------") - println("Started at " + startTime + ", finished at " + finishTime) + println("Started at " + startTime + ", finished at " + finishTime) println("Sent " + count + " messages of size " + size + " in " + ms + " ms (" + tput + " MB/s)") println("--------------------------") println() @@ -416,7 +645,7 @@ private[spark] object ConnectionManager { println("--------------------------") println("Parallel Decreasing Sending") println("--------------------------") - val size = 10 * 1024 * 1024 + val size = 10 * 1024 * 1024 val count = 10 val buffers = Array.tabulate(count)(i => ByteBuffer.allocate(size * (i + 1)).put(Array.tabulate[Byte](size * (i + 1))(x => x.toByte))) buffers.foreach(_.flip) @@ -431,7 +660,7 @@ private[spark] object ConnectionManager { if (!g.isDefined) println("Failed") }) val finishTime = System.currentTimeMillis - + val ms = finishTime - startTime val tput = mb * 1000.0 / ms println("--------------------------") @@ -445,7 +674,7 @@ private[spark] object ConnectionManager { println("--------------------------") println("Continuous Sending") println("--------------------------") - val size = 10 * 1024 * 1024 + val size = 10 * 1024 * 1024 val 
count = 10 val buffer = ByteBuffer.allocate(size).put(Array.tabulate[Byte](size)(x => x.toByte)) diff --git a/core/src/main/scala/spark/network/ConnectionManagerId.scala b/core/src/main/scala/spark/network/ConnectionManagerId.scala new file mode 100644 index 0000000000..b554e84251 --- /dev/null +++ b/core/src/main/scala/spark/network/ConnectionManagerId.scala @@ -0,0 +1,21 @@ +package spark.network + +import java.net.InetSocketAddress + +import spark.Utils + + +private[spark] case class ConnectionManagerId(host: String, port: Int) { + // DEBUG code + Utils.checkHost(host) + assert (port > 0) + + def toSocketAddress() = new InetSocketAddress(host, port) +} + + +private[spark] object ConnectionManagerId { + def fromSocketAddress(socketAddress: InetSocketAddress): ConnectionManagerId = { + new ConnectionManagerId(socketAddress.getHostName(), socketAddress.getPort()) + } +} diff --git a/core/src/main/scala/spark/network/Message.scala b/core/src/main/scala/spark/network/Message.scala index 525751b5bf..d4f03610eb 100644 --- a/core/src/main/scala/spark/network/Message.scala +++ b/core/src/main/scala/spark/network/Message.scala @@ -1,55 +1,10 @@ package spark.network -import spark._ - -import scala.collection.mutable.ArrayBuffer - import java.nio.ByteBuffer -import java.net.InetAddress import java.net.InetSocketAddress -import storage.BlockManager - -private[spark] class MessageChunkHeader( - val typ: Long, - val id: Int, - val totalSize: Int, - val chunkSize: Int, - val other: Int, - val address: InetSocketAddress) { - lazy val buffer = { - val ip = address.getAddress.getAddress() - val port = address.getPort() - ByteBuffer. - allocate(MessageChunkHeader.HEADER_SIZE). - putLong(typ). - putInt(id). - putInt(totalSize). - putInt(chunkSize). - putInt(other). - putInt(ip.size). - put(ip). - putInt(port). - position(MessageChunkHeader.HEADER_SIZE). 
- flip.asInstanceOf[ByteBuffer] - } - - override def toString = "" + this.getClass.getSimpleName + ":" + id + " of type " + typ + - " and sizes " + totalSize + " / " + chunkSize + " bytes" -} -private[spark] class MessageChunk(val header: MessageChunkHeader, val buffer: ByteBuffer) { - val size = if (buffer == null) 0 else buffer.remaining - lazy val buffers = { - val ab = new ArrayBuffer[ByteBuffer]() - ab += header.buffer - if (buffer != null) { - ab += buffer - } - ab - } +import scala.collection.mutable.ArrayBuffer - override def toString = "" + this.getClass.getSimpleName + " (id = " + header.id + ", size = " + size + ")" -} private[spark] abstract class Message(val typ: Long, val id: Int) { var senderAddress: InetSocketAddress = null @@ -58,120 +13,16 @@ private[spark] abstract class Message(val typ: Long, val id: Int) { var finishTime = -1L def size: Int - + def getChunkForSending(maxChunkSize: Int): Option[MessageChunk] - + def getChunkForReceiving(chunkSize: Int): Option[MessageChunk] - + def timeTaken(): String = (finishTime - startTime).toString + " ms" override def toString = this.getClass.getSimpleName + "(id = " + id + ", size = " + size + ")" } -private[spark] class BufferMessage(id_ : Int, val buffers: ArrayBuffer[ByteBuffer], var ackId: Int) -extends Message(Message.BUFFER_MESSAGE, id_) { - - val initialSize = currentSize() - var gotChunkForSendingOnce = false - - def size = initialSize - - def currentSize() = { - if (buffers == null || buffers.isEmpty) { - 0 - } else { - buffers.map(_.remaining).reduceLeft(_ + _) - } - } - - def getChunkForSending(maxChunkSize: Int): Option[MessageChunk] = { - if (maxChunkSize <= 0) { - throw new Exception("Max chunk size is " + maxChunkSize) - } - - if (size == 0 && gotChunkForSendingOnce == false) { - val newChunk = new MessageChunk(new MessageChunkHeader(typ, id, 0, 0, ackId, senderAddress), null) - gotChunkForSendingOnce = true - return Some(newChunk) - } - - while(!buffers.isEmpty) { - val buffer = buffers(0) 
- if (buffer.remaining == 0) { - BlockManager.dispose(buffer) - buffers -= buffer - } else { - val newBuffer = if (buffer.remaining <= maxChunkSize) { - buffer.duplicate() - } else { - buffer.slice().limit(maxChunkSize).asInstanceOf[ByteBuffer] - } - buffer.position(buffer.position + newBuffer.remaining) - val newChunk = new MessageChunk(new MessageChunkHeader( - typ, id, size, newBuffer.remaining, ackId, senderAddress), newBuffer) - gotChunkForSendingOnce = true - return Some(newChunk) - } - } - None - } - - def getChunkForReceiving(chunkSize: Int): Option[MessageChunk] = { - // STRONG ASSUMPTION: BufferMessage created when receiving data has ONLY ONE data buffer - if (buffers.size > 1) { - throw new Exception("Attempting to get chunk from message with multiple data buffers") - } - val buffer = buffers(0) - if (buffer.remaining > 0) { - if (buffer.remaining < chunkSize) { - throw new Exception("Not enough space in data buffer for receiving chunk") - } - val newBuffer = buffer.slice().limit(chunkSize).asInstanceOf[ByteBuffer] - buffer.position(buffer.position + newBuffer.remaining) - val newChunk = new MessageChunk(new MessageChunkHeader( - typ, id, size, newBuffer.remaining, ackId, senderAddress), newBuffer) - return Some(newChunk) - } - None - } - - def flip() { - buffers.foreach(_.flip) - } - - def hasAckId() = (ackId != 0) - - def isCompletelyReceived() = !buffers(0).hasRemaining - - override def toString = { - if (hasAckId) { - "BufferAckMessage(aid = " + ackId + ", id = " + id + ", size = " + size + ")" - } else { - "BufferMessage(id = " + id + ", size = " + size + ")" - } - } -} - -private[spark] object MessageChunkHeader { - val HEADER_SIZE = 40 - - def create(buffer: ByteBuffer): MessageChunkHeader = { - if (buffer.remaining != HEADER_SIZE) { - throw new IllegalArgumentException("Cannot convert buffer data to Message") - } - val typ = buffer.getLong() - val id = buffer.getInt() - val totalSize = buffer.getInt() - val chunkSize = buffer.getInt() - val other 
= buffer.getInt() - val ipSize = buffer.getInt() - val ipBytes = new Array[Byte](ipSize) - buffer.get(ipBytes) - val ip = InetAddress.getByAddress(ipBytes) - val port = buffer.getInt() - new MessageChunkHeader(typ, id, totalSize, chunkSize, other, new InetSocketAddress(ip, port)) - } -} private[spark] object Message { val BUFFER_MESSAGE = 1111111111L @@ -180,14 +31,16 @@ private[spark] object Message { def getNewId() = synchronized { lastId += 1 - if (lastId == 0) lastId += 1 + if (lastId == 0) { + lastId += 1 + } lastId } def createBufferMessage(dataBuffers: Seq[ByteBuffer], ackId: Int): BufferMessage = { if (dataBuffers == null) { return new BufferMessage(getNewId(), new ArrayBuffer[ByteBuffer], ackId) - } + } if (dataBuffers.exists(_ == null)) { throw new Exception("Attempting to create buffer message with null buffer") } @@ -196,7 +49,7 @@ private[spark] object Message { def createBufferMessage(dataBuffers: Seq[ByteBuffer]): BufferMessage = createBufferMessage(dataBuffers, 0) - + def createBufferMessage(dataBuffer: ByteBuffer, ackId: Int): BufferMessage = { if (dataBuffer == null) { return createBufferMessage(Array(ByteBuffer.allocate(0)), ackId) @@ -204,15 +57,18 @@ private[spark] object Message { return createBufferMessage(Array(dataBuffer), ackId) } } - - def createBufferMessage(dataBuffer: ByteBuffer): BufferMessage = + + def createBufferMessage(dataBuffer: ByteBuffer): BufferMessage = createBufferMessage(dataBuffer, 0) - - def createBufferMessage(ackId: Int): BufferMessage = createBufferMessage(new Array[ByteBuffer](0), ackId) + + def createBufferMessage(ackId: Int): BufferMessage = { + createBufferMessage(new Array[ByteBuffer](0), ackId) + } def create(header: MessageChunkHeader): Message = { val newMessage: Message = header.typ match { - case BUFFER_MESSAGE => new BufferMessage(header.id, ArrayBuffer(ByteBuffer.allocate(header.totalSize)), header.other) + case BUFFER_MESSAGE => new BufferMessage(header.id, + 
ArrayBuffer(ByteBuffer.allocate(header.totalSize)), header.other) } newMessage.senderAddress = header.address newMessage diff --git a/core/src/main/scala/spark/network/MessageChunk.scala b/core/src/main/scala/spark/network/MessageChunk.scala new file mode 100644 index 0000000000..aaf9204d0e --- /dev/null +++ b/core/src/main/scala/spark/network/MessageChunk.scala @@ -0,0 +1,25 @@ +package spark.network + +import java.nio.ByteBuffer + +import scala.collection.mutable.ArrayBuffer + + +private[network] +class MessageChunk(val header: MessageChunkHeader, val buffer: ByteBuffer) { + + val size = if (buffer == null) 0 else buffer.remaining + + lazy val buffers = { + val ab = new ArrayBuffer[ByteBuffer]() + ab += header.buffer + if (buffer != null) { + ab += buffer + } + ab + } + + override def toString = { + "" + this.getClass.getSimpleName + " (id = " + header.id + ", size = " + size + ")" + } +} diff --git a/core/src/main/scala/spark/network/MessageChunkHeader.scala b/core/src/main/scala/spark/network/MessageChunkHeader.scala new file mode 100644 index 0000000000..3693d509d6 --- /dev/null +++ b/core/src/main/scala/spark/network/MessageChunkHeader.scala @@ -0,0 +1,58 @@ +package spark.network + +import java.net.InetAddress +import java.net.InetSocketAddress +import java.nio.ByteBuffer + + +private[spark] class MessageChunkHeader( + val typ: Long, + val id: Int, + val totalSize: Int, + val chunkSize: Int, + val other: Int, + val address: InetSocketAddress) { + lazy val buffer = { + // No need to change this, at 'use' time, we do a reverse lookup of the hostname. + // Refer to network.Connection + val ip = address.getAddress.getAddress() + val port = address.getPort() + ByteBuffer. + allocate(MessageChunkHeader.HEADER_SIZE). + putLong(typ). + putInt(id). + putInt(totalSize). + putInt(chunkSize). + putInt(other). + putInt(ip.size). + put(ip). + putInt(port). + position(MessageChunkHeader.HEADER_SIZE). 
+ flip.asInstanceOf[ByteBuffer] + } + + override def toString = "" + this.getClass.getSimpleName + ":" + id + " of type " + typ + + " and sizes " + totalSize + " / " + chunkSize + " bytes" +} + + +private[spark] object MessageChunkHeader { + val HEADER_SIZE = 40 + + def create(buffer: ByteBuffer): MessageChunkHeader = { + if (buffer.remaining != HEADER_SIZE) { + throw new IllegalArgumentException("Cannot convert buffer data to Message") + } + val typ = buffer.getLong() + val id = buffer.getInt() + val totalSize = buffer.getInt() + val chunkSize = buffer.getInt() + val other = buffer.getInt() + val ipSize = buffer.getInt() + val ipBytes = new Array[Byte](ipSize) + buffer.get(ipBytes) + val ip = InetAddress.getByAddress(ipBytes) + val port = buffer.getInt() + new MessageChunkHeader(typ, id, totalSize, chunkSize, other, new InetSocketAddress(ip, port)) + } +} diff --git a/core/src/main/scala/spark/network/netty/FileHeader.scala b/core/src/main/scala/spark/network/netty/FileHeader.scala new file mode 100644 index 0000000000..aed4254234 --- /dev/null +++ b/core/src/main/scala/spark/network/netty/FileHeader.scala @@ -0,0 +1,57 @@ +package spark.network.netty + +import io.netty.buffer._ + +import spark.Logging + +private[spark] class FileHeader ( + val fileLen: Int, + val blockId: String) extends Logging { + + lazy val buffer = { + val buf = Unpooled.buffer() + buf.capacity(FileHeader.HEADER_SIZE) + buf.writeInt(fileLen) + buf.writeInt(blockId.length) + blockId.foreach((x: Char) => buf.writeByte(x)) + //padding the rest of header + if (FileHeader.HEADER_SIZE - buf.readableBytes > 0 ) { + buf.writeZero(FileHeader.HEADER_SIZE - buf.readableBytes) + } else { + throw new Exception("too long header " + buf.readableBytes) + logInfo("too long header") + } + buf + } + +} + +private[spark] object FileHeader { + + val HEADER_SIZE = 40 + + def getFileLenOffset = 0 + def getFileLenSize = Integer.SIZE/8 + + def create(buf: ByteBuf): FileHeader = { + val length = buf.readInt + val 
idLength = buf.readInt + val idBuilder = new StringBuilder(idLength) + for (i <- 1 to idLength) { + idBuilder += buf.readByte().asInstanceOf[Char] + } + val blockId = idBuilder.toString() + new FileHeader(length, blockId) + } + + + def main (args:Array[String]){ + + val header = new FileHeader(25,"block_0"); + val buf = header.buffer; + val newheader = FileHeader.create(buf); + System.out.println("id="+newheader.blockId+",size="+newheader.fileLen) + + } +} + diff --git a/core/src/main/scala/spark/network/netty/ShuffleCopier.scala b/core/src/main/scala/spark/network/netty/ShuffleCopier.scala new file mode 100644 index 0000000000..8d5194a737 --- /dev/null +++ b/core/src/main/scala/spark/network/netty/ShuffleCopier.scala @@ -0,0 +1,101 @@ +package spark.network.netty + +import java.util.concurrent.Executors + +import io.netty.buffer.ByteBuf +import io.netty.channel.ChannelHandlerContext +import io.netty.util.CharsetUtil + +import spark.Logging +import spark.network.ConnectionManagerId + +import scala.collection.JavaConverters._ + + +private[spark] class ShuffleCopier extends Logging { + + def getBlock(host: String, port: Int, blockId: String, + resultCollectCallback: (String, Long, ByteBuf) => Unit) { + + val handler = new ShuffleCopier.ShuffleClientHandler(resultCollectCallback) + val connectTimeout = System.getProperty("spark.shuffle.netty.connect.timeout", "60000").toInt + val fc = new FileClient(handler, connectTimeout) + + try { + fc.init() + fc.connect(host, port) + fc.sendRequest(blockId) + fc.waitForClose() + fc.close() + } catch { + // Handle any socket-related exceptions in FileClient + case e: Exception => { + logError("Shuffle copy of block " + blockId + " from " + host + ":" + port + " failed", e) + handler.handleError(blockId) + } + } + } + + def getBlock(cmId: ConnectionManagerId, blockId: String, + resultCollectCallback: (String, Long, ByteBuf) => Unit) { + getBlock(cmId.host, cmId.port, blockId, resultCollectCallback) + } + + def getBlocks(cmId: 
ConnectionManagerId, + blocks: Seq[(String, Long)], + resultCollectCallback: (String, Long, ByteBuf) => Unit) { + + for ((blockId, size) <- blocks) { + getBlock(cmId, blockId, resultCollectCallback) + } + } +} + + +private[spark] object ShuffleCopier extends Logging { + + private class ShuffleClientHandler(resultCollectCallBack: (String, Long, ByteBuf) => Unit) + extends FileClientHandler with Logging { + + override def handle(ctx: ChannelHandlerContext, in: ByteBuf, header: FileHeader) { + logDebug("Received Block: " + header.blockId + " (" + header.fileLen + "B)"); + resultCollectCallBack(header.blockId, header.fileLen.toLong, in.readBytes(header.fileLen)) + } + + override def handleError(blockId: String) { + if (!isComplete) { + resultCollectCallBack(blockId, -1, null) + } + } + } + + def echoResultCollectCallBack(blockId: String, size: Long, content: ByteBuf) { + if (size != -1) { + logInfo("File: " + blockId + " content is : \" " + content.toString(CharsetUtil.UTF_8) + "\"") + } + } + + def main(args: Array[String]) { + if (args.length < 3) { + System.err.println("Usage: ShuffleCopier <host> <port> <shuffle_block_id> <threads>") + System.exit(1) + } + val host = args(0) + val port = args(1).toInt + val file = args(2) + val threads = if (args.length > 3) args(3).toInt else 10 + + val copiers = Executors.newFixedThreadPool(80) + val tasks = (for (i <- Range(0, threads)) yield { + Executors.callable(new Runnable() { + def run() { + val copier = new ShuffleCopier() + copier.getBlock(host, port, file, echoResultCollectCallBack) + } + }) + }).asJava + copiers.invokeAll(tasks) + copiers.shutdown + System.exit(0) + } +} diff --git a/core/src/main/scala/spark/network/netty/ShuffleSender.scala b/core/src/main/scala/spark/network/netty/ShuffleSender.scala new file mode 100644 index 0000000000..d6fa4b1e80 --- /dev/null +++ b/core/src/main/scala/spark/network/netty/ShuffleSender.scala @@ -0,0 +1,53 @@ +package spark.network.netty + +import java.io.File + +import 
spark.Logging + + +private[spark] class ShuffleSender(portIn: Int, val pResolver: PathResolver) extends Logging { + + val server = new FileServer(pResolver, portIn) + server.start() + + def stop() { + server.stop() + } + + def port: Int = server.getPort() +} + + +/** + * An application for testing the shuffle sender as a standalone program. + */ +private[spark] object ShuffleSender { + + def main(args: Array[String]) { + if (args.length < 3) { + System.err.println( + "Usage: ShuffleSender <port> <subDirsPerLocalDir> <list of shuffle_block_directories>") + System.exit(1) + } + + val port = args(0).toInt + val subDirsPerLocalDir = args(1).toInt + val localDirs = args.drop(2).map(new File(_)) + + val pResovler = new PathResolver { + override def getAbsolutePath(blockId: String): String = { + if (!blockId.startsWith("shuffle_")) { + throw new Exception("Block " + blockId + " is not a shuffle block") + } + // Figure out which local directory it hashes to, and which subdirectory in that + val hash = math.abs(blockId.hashCode) + val dirId = hash % localDirs.length + val subDirId = (hash / localDirs.length) % subDirsPerLocalDir + val subDir = new File(localDirs(dirId), "%02x".format(subDirId)) + val file = new File(subDir, blockId) + return file.getAbsolutePath + } + } + val sender = new ShuffleSender(port, pResovler) + } +} diff --git a/core/src/main/scala/spark/rdd/BlockRDD.scala b/core/src/main/scala/spark/rdd/BlockRDD.scala index f44d37a91f..3e60860b3e 100644 --- a/core/src/main/scala/spark/rdd/BlockRDD.scala +++ b/core/src/main/scala/spark/rdd/BlockRDD.scala @@ -1,8 +1,9 @@ package spark.rdd -import scala.collection.mutable.HashMap import scala.reflect.ClassTag + import spark.{RDD, SparkContext, SparkEnv, Partition, TaskContext} +import spark.storage.BlockManager private[spark] class BlockRDDPartition(val blockId: String, idx: Int) extends Partition { val index = idx @@ -12,12 +13,7 @@ private[spark] class BlockRDD[T: ClassTag](sc: SparkContext, @transient blockIds: 
Array[String]) extends RDD[T](sc, Nil) { - @transient lazy val locations_ = { - val blockManager = SparkEnv.get.blockManager - /*val locations = blockIds.map(id => blockManager.getLocations(id))*/ - val locations = blockManager.getLocations(blockIds) - HashMap(blockIds.zip(locations):_*) - } + @transient lazy val locations_ = BlockManager.blockIdsToExecutorLocations(blockIds, SparkEnv.get) override def getPartitions: Array[Partition] = (0 until blockIds.size).map(i => { new BlockRDDPartition(blockIds(i), i).asInstanceOf[Partition] diff --git a/core/src/main/scala/spark/rdd/CheckpointRDD.scala b/core/src/main/scala/spark/rdd/CheckpointRDD.scala index 700a4160c8..efd29fa561 100644 --- a/core/src/main/scala/spark/rdd/CheckpointRDD.scala +++ b/core/src/main/scala/spark/rdd/CheckpointRDD.scala @@ -9,6 +9,7 @@ import org.apache.hadoop.util.ReflectionUtils import org.apache.hadoop.fs.Path import java.io.{File, IOException, EOFException} import java.text.NumberFormat +import spark.deploy.SparkHadoopUtil private[spark] class CheckpointRDDPartition(val index: Int) extends Partition {} @@ -22,13 +23,20 @@ class CheckpointRDD[T: ClassTag](sc: SparkContext, val checkpointPath: String) @transient val fs = new Path(checkpointPath).getFileSystem(sc.hadoopConfiguration) override def getPartitions: Array[Partition] = { - val dirContents = fs.listStatus(new Path(checkpointPath)) - val partitionFiles = dirContents.map(_.getPath.toString).filter(_.contains("part-")).sorted - val numPartitions = partitionFiles.size - if (numPartitions > 0 && (! partitionFiles(0).endsWith(CheckpointRDD.splitIdToFile(0)) || - ! partitionFiles(numPartitions-1).endsWith(CheckpointRDD.splitIdToFile(numPartitions-1)))) { - throw new SparkException("Invalid checkpoint directory: " + checkpointPath) - } + val cpath = new Path(checkpointPath) + val numPartitions = + // listStatus can throw exception if path does not exist. 
+ if (fs.exists(cpath)) { + val dirContents = fs.listStatus(cpath) + val partitionFiles = dirContents.map(_.getPath.toString).filter(_.contains("part-")).sorted + val numPart = partitionFiles.size + if (numPart > 0 && (! partitionFiles(0).endsWith(CheckpointRDD.splitIdToFile(0)) || + ! partitionFiles(numPart-1).endsWith(CheckpointRDD.splitIdToFile(numPart-1)))) { + throw new SparkException("Invalid checkpoint directory: " + checkpointPath) + } + numPart + } else 0 + Array.tabulate(numPartitions)(i => new CheckpointRDDPartition(i)) } @@ -36,7 +44,7 @@ class CheckpointRDD[T: ClassTag](sc: SparkContext, val checkpointPath: String) checkpointData.get.cpFile = Some(checkpointPath) override def getPreferredLocations(split: Partition): Seq[String] = { - val status = fs.getFileStatus(new Path(checkpointPath)) + val status = fs.getFileStatus(new Path(checkpointPath, CheckpointRDD.splitIdToFile(split.index))) val locations = fs.getFileBlockLocations(status, 0, status.getLen) locations.headOption.toList.flatMap(_.getHosts).filter(_ != "localhost") } @@ -59,7 +67,7 @@ private[spark] object CheckpointRDD extends Logging { def writeToFile[T](path: String, blockSize: Int = -1)(ctx: TaskContext, iterator: Iterator[T]) { val outputDir = new Path(path) - val fs = outputDir.getFileSystem(new Configuration()) + val fs = outputDir.getFileSystem(SparkHadoopUtil.newConfiguration()) val finalOutputName = splitIdToFile(ctx.splitId) val finalOutputPath = new Path(outputDir, finalOutputName) @@ -84,6 +92,7 @@ private[spark] object CheckpointRDD extends Logging { if (!fs.rename(tempOutputPath, finalOutputPath)) { if (!fs.exists(finalOutputPath)) { + logInfo("Deleting tempOutputPath " + tempOutputPath) fs.delete(tempOutputPath, false) throw new IOException("Checkpoint failed: failed to save output of task: " + ctx.attemptId + " and final output path does not exist") @@ -96,7 +105,7 @@ private[spark] object CheckpointRDD extends Logging { } def readFromFile[T](path: Path, context: TaskContext): 
Iterator[T] = { - val fs = path.getFileSystem(new Configuration()) + val fs = path.getFileSystem(SparkHadoopUtil.newConfiguration()) val bufferSize = System.getProperty("spark.buffer.size", "65536").toInt val fileInputStream = fs.open(path, bufferSize) val serializer = SparkEnv.get.serializer.newInstance() @@ -118,7 +127,7 @@ private[spark] object CheckpointRDD extends Logging { val sc = new SparkContext(cluster, "CheckpointRDD Test") val rdd = sc.makeRDD(1 to 10, 10).flatMap(x => 1 to 10000) val path = new Path(hdfsPath, "temp") - val fs = path.getFileSystem(new Configuration()) + val fs = path.getFileSystem(SparkHadoopUtil.newConfiguration()) sc.runJob(rdd, CheckpointRDD.writeToFile(path.toString, 1024) _) val cpRDD = new CheckpointRDD[Int](sc, path.toString) assert(cpRDD.partitions.length == rdd.partitions.length, "Number of partitions is not the same") diff --git a/core/src/main/scala/spark/rdd/CoGroupedRDD.scala b/core/src/main/scala/spark/rdd/CoGroupedRDD.scala index a6235491ca..8966f9f86e 100644 --- a/core/src/main/scala/spark/rdd/CoGroupedRDD.scala +++ b/core/src/main/scala/spark/rdd/CoGroupedRDD.scala @@ -6,7 +6,7 @@ import java.util.{HashMap => JHashMap} import scala.collection.JavaConversions import scala.collection.mutable.ArrayBuffer -import spark.{Aggregator, Logging, Partition, Partitioner, RDD, SparkEnv, TaskContext} +import spark.{Aggregator, Partition, Partitioner, RDD, SparkEnv, TaskContext} import spark.{Dependency, OneToOneDependency, ShuffleDependency} @@ -49,12 +49,17 @@ private[spark] class CoGroupAggregator * * @param rdds parent RDDs. * @param part partitioner used to partition the shuffle output. - * @param mapSideCombine flag indicating whether to merge values before shuffle step. + * @param mapSideCombine flag indicating whether to merge values before shuffle step. If the flag + * is on, Spark does an extra pass over the data on the map side to merge + * all values belonging to the same key together. 
This can reduce the amount + * of data shuffled if and only if the number of distinct keys is very small, + * and the ratio of key size to value size is also very small. */ class CoGroupedRDD[K]( @transient var rdds: Seq[RDD[(K, _)]], part: Partitioner, - val mapSideCombine: Boolean = true) + val mapSideCombine: Boolean = false, + val serializerClass: String = null) extends RDD[(K, Seq[Seq[_]])](rdds.head.context, Nil) { private val aggr = new CoGroupAggregator @@ -68,9 +73,9 @@ class CoGroupedRDD[K]( logInfo("Adding shuffle dependency with " + rdd) if (mapSideCombine) { val mapSideCombinedRDD = rdd.mapPartitions(aggr.combineValuesByKey(_), true) - new ShuffleDependency[Any, ArrayBuffer[Any]](mapSideCombinedRDD, part) + new ShuffleDependency[Any, ArrayBuffer[Any]](mapSideCombinedRDD, part, serializerClass) } else { - new ShuffleDependency[Any, Any](rdd.asInstanceOf[RDD[(Any, Any)]], part) + new ShuffleDependency[Any, Any](rdd.asInstanceOf[RDD[(Any, Any)]], part, serializerClass) } } } @@ -112,6 +117,7 @@ class CoGroupedRDD[K]( } } + val ser = SparkEnv.get.serializerManager.get(serializerClass) for ((dep, depNum) <- split.deps.zipWithIndex) dep match { case NarrowCoGroupSplitDep(rdd, _, itsSplit) => { // Read them from the parent @@ -124,12 +130,12 @@ class CoGroupedRDD[K]( val fetcher = SparkEnv.get.shuffleFetcher if (mapSideCombine) { // With map side combine on, for each key, the shuffle fetcher returns a list of values. - fetcher.fetch[K, Seq[Any]](shuffleId, split.index, context.taskMetrics).foreach { + fetcher.fetch[K, Seq[Any]](shuffleId, split.index, context.taskMetrics, ser).foreach { case (key, values) => getSeq(key)(depNum) ++= values } } else { // With map side combine off, for each key the shuffle fetcher returns a single value. 
- fetcher.fetch[K, Any](shuffleId, split.index, context.taskMetrics).foreach { + fetcher.fetch[K, Any](shuffleId, split.index, context.taskMetrics, ser).foreach { case (key, value) => getSeq(key)(depNum) += value } } diff --git a/core/src/main/scala/spark/rdd/EmptyRDD.scala b/core/src/main/scala/spark/rdd/EmptyRDD.scala new file mode 100644 index 0000000000..e4dd3a7fa7 --- /dev/null +++ b/core/src/main/scala/spark/rdd/EmptyRDD.scala @@ -0,0 +1,16 @@ +package spark.rdd + +import spark.{RDD, SparkContext, SparkEnv, Partition, TaskContext} + + +/** + * An RDD that is empty, i.e. has no element in it. + */ +class EmptyRDD[T: ClassManifest](sc: SparkContext) extends RDD[T](sc, Nil) { + + override def getPartitions: Array[Partition] = Array.empty + + override def compute(split: Partition, context: TaskContext): Iterator[T] = { + throw new UnsupportedOperationException("empty RDD") + } +} diff --git a/core/src/main/scala/spark/rdd/JdbcRDD.scala b/core/src/main/scala/spark/rdd/JdbcRDD.scala new file mode 100644 index 0000000000..a50f407737 --- /dev/null +++ b/core/src/main/scala/spark/rdd/JdbcRDD.scala @@ -0,0 +1,103 @@ +package spark.rdd + +import java.sql.{Connection, ResultSet} + +import spark.{Logging, Partition, RDD, SparkContext, TaskContext} +import spark.util.NextIterator + +private[spark] class JdbcPartition(idx: Int, val lower: Long, val upper: Long) extends Partition { + override def index = idx +} + +/** + * An RDD that executes an SQL query on a JDBC connection and reads results. + * For usage example, see test case JdbcRDDSuite. + * + * @param getConnection a function that returns an open Connection. + * The RDD takes care of closing the connection. + * @param sql the text of the query. + * The query must contain two ? placeholders for parameters used to partition the results. + * E.g. "select title, author from books where ? <= id and id <= ?" 
+ * @param lowerBound the minimum value of the first placeholder + * @param upperBound the maximum value of the second placeholder + * The lower and upper bounds are inclusive. + * @param numPartitions the number of partitions. + * Given a lowerBound of 1, an upperBound of 20, and a numPartitions of 2, + * the query would be executed twice, once with (1, 10) and once with (11, 20) + * @param mapRow a function from a ResultSet to a single row of the desired result type(s). + * This should only call getInt, getString, etc; the RDD takes care of calling next. + * The default maps a ResultSet to an array of Object. + */ +class JdbcRDD[T: ClassManifest]( + sc: SparkContext, + getConnection: () => Connection, + sql: String, + lowerBound: Long, + upperBound: Long, + numPartitions: Int, + mapRow: (ResultSet) => T = JdbcRDD.resultSetToObjectArray _) + extends RDD[T](sc, Nil) with Logging { + + override def getPartitions: Array[Partition] = { + // bounds are inclusive, hence the + 1 here and - 1 on end + val length = 1 + upperBound - lowerBound + (0 until numPartitions).map(i => { + val start = lowerBound + ((i * length) / numPartitions).toLong + val end = lowerBound + (((i + 1) * length) / numPartitions).toLong - 1 + new JdbcPartition(i, start, end) + }).toArray + } + + override def compute(thePart: Partition, context: TaskContext) = new NextIterator[T] { + context.addOnCompleteCallback{ () => closeIfNeeded() } + val part = thePart.asInstanceOf[JdbcPartition] + val conn = getConnection() + val stmt = conn.prepareStatement(sql, ResultSet.TYPE_FORWARD_ONLY, ResultSet.CONCUR_READ_ONLY) + + // setFetchSize(Integer.MIN_VALUE) is a mysql driver specific way to force streaming results, + // rather than pulling entire resultset into memory. 
+ // see http://dev.mysql.com/doc/refman/5.0/en/connector-j-reference-implementation-notes.html + if (conn.getMetaData.getURL.matches("jdbc:mysql:.*")) { + stmt.setFetchSize(Integer.MIN_VALUE) + logInfo("statement fetch size set to: " + stmt.getFetchSize + " to force MySQL streaming ") + } + + stmt.setLong(1, part.lower) + stmt.setLong(2, part.upper) + val rs = stmt.executeQuery() + + override def getNext: T = { + if (rs.next()) { + mapRow(rs) + } else { + finished = true + null.asInstanceOf[T] + } + } + + override def close() { + try { + if (null != rs && ! rs.isClosed()) rs.close() + } catch { + case e: Exception => logWarning("Exception closing resultset", e) + } + try { + if (null != stmt && ! stmt.isClosed()) stmt.close() + } catch { + case e: Exception => logWarning("Exception closing statement", e) + } + try { + if (null != conn && ! stmt.isClosed()) conn.close() + logInfo("closed connection") + } catch { + case e: Exception => logWarning("Exception closing connection", e) + } + } + } +} + +object JdbcRDD { + def resultSetToObjectArray(rs: ResultSet) = { + Array.tabulate[Object](rs.getMetaData.getColumnCount)(i => rs.getObject(i + 1)) + } +} diff --git a/core/src/main/scala/spark/rdd/NewHadoopRDD.scala b/core/src/main/scala/spark/rdd/NewHadoopRDD.scala index bdd974590a..901d01ef30 100644 --- a/core/src/main/scala/spark/rdd/NewHadoopRDD.scala +++ b/core/src/main/scala/spark/rdd/NewHadoopRDD.scala @@ -57,7 +57,7 @@ class NewHadoopRDD[K, V]( override def compute(theSplit: Partition, context: TaskContext) = new Iterator[(K, V)] { val split = theSplit.asInstanceOf[NewHadoopPartition] val conf = confBroadcast.value.value - val attemptId = new TaskAttemptID(jobtrackerId, id, true, split.index, 0) + val attemptId = newTaskAttemptID(jobtrackerId, id, true, split.index, 0) val hadoopAttemptContext = newTaskAttemptContext(conf, attemptId) val format = inputFormatClass.newInstance if (format.isInstanceOf[Configurable]) { diff --git 
a/core/src/main/scala/spark/rdd/PipedRDD.scala b/core/src/main/scala/spark/rdd/PipedRDD.scala index 34d32eb85a..349e6162c4 100644 --- a/core/src/main/scala/spark/rdd/PipedRDD.scala +++ b/core/src/main/scala/spark/rdd/PipedRDD.scala @@ -10,6 +10,7 @@ import scala.io.Source import scala.reflect.ClassTag import spark.{RDD, SparkEnv, Partition, TaskContext} +import spark.broadcast.Broadcast /** @@ -19,14 +20,21 @@ import spark.{RDD, SparkEnv, Partition, TaskContext} class PipedRDD[T: ClassTag]( prev: RDD[T], command: Seq[String], - envVars: Map[String, String]) + envVars: Map[String, String], + printPipeContext: (String => Unit) => Unit, + printRDDElement: (T, String => Unit) => Unit) extends RDD[String](prev) { - def this(prev: RDD[T], command: Seq[String]) = this(prev, command, Map()) - // Similar to Runtime.exec(), if we are given a single string, split it into words // using a standard StringTokenizer (i.e. by spaces) - def this(prev: RDD[T], command: String) = this(prev, PipedRDD.tokenize(command)) + def this( + prev: RDD[T], + command: String, + envVars: Map[String, String] = Map(), + printPipeContext: (String => Unit) => Unit = null, + printRDDElement: (T, String => Unit) => Unit = null) = + this(prev, PipedRDD.tokenize(command), envVars, printPipeContext, printRDDElement) + override def getPartitions: Array[Partition] = firstParent[T].partitions @@ -53,8 +61,17 @@ class PipedRDD[T: ClassTag]( override def run() { SparkEnv.set(env) val out = new PrintWriter(proc.getOutputStream) + + // input the pipe context firstly + if (printPipeContext != null) { + printPipeContext(out.println(_)) + } for (elem <- firstParent[T].iterator(split, context)) { - out.println(elem) + if (printRDDElement != null) { + printRDDElement(elem, out.println(_)) + } else { + out.println(elem) + } } out.close() } diff --git a/core/src/main/scala/spark/rdd/ShuffledRDD.scala b/core/src/main/scala/spark/rdd/ShuffledRDD.scala index 4e33b7dd5c..c7d1926b83 100644 --- 
a/core/src/main/scala/spark/rdd/ShuffledRDD.scala +++ b/core/src/main/scala/spark/rdd/ShuffledRDD.scala @@ -3,6 +3,7 @@ package spark.rdd import spark.{Partitioner, RDD, SparkEnv, ShuffleDependency, Partition, TaskContext} import spark.SparkContext._ + private[spark] class ShuffledRDDPartition(val idx: Int) extends Partition { override val index = idx override def hashCode(): Int = idx @@ -12,13 +13,15 @@ private[spark] class ShuffledRDDPartition(val idx: Int) extends Partition { * The resulting RDD from a shuffle (e.g. repartitioning of data). * @param prev the parent RDD. * @param part the partitioner used to partition the RDD + * @param serializerClass class name of the serializer to use. * @tparam K the key class. * @tparam V the value class. */ class ShuffledRDD[K, V]( @transient prev: RDD[(K, V)], - part: Partitioner) - extends RDD[(K, V)](prev.context, List(new ShuffleDependency(prev, part))) { + part: Partitioner, + serializerClass: String = null) + extends RDD[(K, V)](prev.context, List(new ShuffleDependency(prev, part, serializerClass))) { override val partitioner = Some(part) @@ -28,6 +31,7 @@ class ShuffledRDD[K, V]( override def compute(split: Partition, context: TaskContext): Iterator[(K, V)] = { val shuffledId = dependencies.head.asInstanceOf[ShuffleDependency[K, V]].shuffleId - SparkEnv.get.shuffleFetcher.fetch[K, V](shuffledId, split.index, context.taskMetrics) + SparkEnv.get.shuffleFetcher.fetch[K, V](shuffledId, split.index, context.taskMetrics, + SparkEnv.get.serializerManager.get(serializerClass)) } } diff --git a/core/src/main/scala/spark/rdd/SubtractedRDD.scala b/core/src/main/scala/spark/rdd/SubtractedRDD.scala index 5e56900b18..9274839bca 100644 --- a/core/src/main/scala/spark/rdd/SubtractedRDD.scala +++ b/core/src/main/scala/spark/rdd/SubtractedRDD.scala @@ -15,6 +15,7 @@ import spark.SparkEnv import spark.ShuffleDependency import spark.OneToOneDependency + /** * An optimized version of cogroup for set difference/subtraction. 
* @@ -34,7 +35,9 @@ import spark.OneToOneDependency private[spark] class SubtractedRDD[K: ClassTag, V: ClassTag, W: ClassTag]( @transient var rdd1: RDD[(K, V)], @transient var rdd2: RDD[(K, W)], - part: Partitioner) extends RDD[(K, V)](rdd1.context, Nil) { + part: Partitioner, + val serializerClass: String = null) + extends RDD[(K, V)](rdd1.context, Nil) { override def getDependencies: Seq[Dependency[_]] = { Seq(rdd1, rdd2).map { rdd => @@ -43,7 +46,7 @@ private[spark] class SubtractedRDD[K: ClassTag, V: ClassTag, W: ClassTag]( new OneToOneDependency(rdd) } else { logInfo("Adding shuffle dependency with " + rdd) - new ShuffleDependency(rdd.asInstanceOf[RDD[(K, Any)]], part) + new ShuffleDependency(rdd.asInstanceOf[RDD[(K, Any)]], part, serializerClass) } } } @@ -68,6 +71,7 @@ private[spark] class SubtractedRDD[K: ClassTag, V: ClassTag, W: ClassTag]( override def compute(p: Partition, context: TaskContext): Iterator[(K, V)] = { val partition = p.asInstanceOf[CoGroupPartition] + val serializer = SparkEnv.get.serializerManager.get(serializerClass) val map = new JHashMap[K, ArrayBuffer[V]] def getSeq(k: K): ArrayBuffer[V] = { val seq = map.get(k) @@ -80,12 +84,16 @@ private[spark] class SubtractedRDD[K: ClassTag, V: ClassTag, W: ClassTag]( } } def integrate(dep: CoGroupSplitDep, op: ((K, V)) => Unit) = dep match { - case NarrowCoGroupSplitDep(rdd, _, itsSplit) => + case NarrowCoGroupSplitDep(rdd, _, itsSplit) => { for (t <- rdd.iterator(itsSplit, context)) op(t.asInstanceOf[(K, V)]) - case ShuffleCoGroupSplitDep(shuffleId) => - for (t <- SparkEnv.get.shuffleFetcher.fetch(shuffleId, partition.index, context.taskMetrics)) + } + case ShuffleCoGroupSplitDep(shuffleId) => { + val iter = SparkEnv.get.shuffleFetcher.fetch(shuffleId, partition.index, + context.taskMetrics, serializer) + for (t <- iter) op(t.asInstanceOf[(K, V)]) + } } // the first dep is rdd1; add all values to the map integrate(partition.deps(0), t => getSeq(t._1) += t._2) diff --git 
a/core/src/main/scala/spark/rdd/ZippedPartitionsRDD.scala b/core/src/main/scala/spark/rdd/ZippedPartitionsRDD.scala new file mode 100644 index 0000000000..b234428ab2 --- /dev/null +++ b/core/src/main/scala/spark/rdd/ZippedPartitionsRDD.scala @@ -0,0 +1,138 @@ +package spark.rdd + +import spark.{Utils, OneToOneDependency, RDD, SparkContext, Partition, TaskContext} +import java.io.{ObjectOutputStream, IOException} + +private[spark] class ZippedPartitionsPartition( + idx: Int, + @transient rdds: Seq[RDD[_]]) + extends Partition { + + override val index: Int = idx + var partitionValues = rdds.map(rdd => rdd.partitions(idx)) + def partitions = partitionValues + + @throws(classOf[IOException]) + private def writeObject(oos: ObjectOutputStream) { + // Update the reference to parent split at the time of task serialization + partitionValues = rdds.map(rdd => rdd.partitions(idx)) + oos.defaultWriteObject() + } +} + +abstract class ZippedPartitionsBaseRDD[V: ClassManifest]( + sc: SparkContext, + var rdds: Seq[RDD[_]]) + extends RDD[V](sc, rdds.map(x => new OneToOneDependency(x))) { + + override def getPartitions: Array[Partition] = { + val sizes = rdds.map(x => x.partitions.size) + if (!sizes.forall(x => x == sizes(0))) { + throw new IllegalArgumentException("Can't zip RDDs with unequal numbers of partitions") + } + val array = new Array[Partition](sizes(0)) + for (i <- 0 until sizes(0)) { + array(i) = new ZippedPartitionsPartition(i, rdds) + } + array + } + + override def getPreferredLocations(s: Partition): Seq[String] = { + // Note that as number of rdd's increase and/or number of slaves in cluster increase, the computed preferredLocations below + // become diminishingly small : so we might need to look at alternate strategies to alleviate this. + // If there are no (or very small number of preferred locations), we will end up transferred the blocks to 'any' node in the + // cluster - paying with n/w and cache cost. + // Maybe pick a node which figures max amount of time ? 
+ // Choose node which is hosting 'larger' of some subset of blocks ? + // Look at rack locality to ensure chosen host is atleast rack local to both hosting node ?, etc (would be good to defer this if possible) + val splits = s.asInstanceOf[ZippedPartitionsPartition].partitions + val rddSplitZip = rdds.zip(splits) + + // exact match. + val exactMatchPreferredLocations = rddSplitZip.map(x => x._1.preferredLocations(x._2)) + val exactMatchLocations = exactMatchPreferredLocations.reduce((x, y) => x.intersect(y)) + + // Remove exact match and then do host local match. + val exactMatchHosts = exactMatchLocations.map(Utils.parseHostPort(_)._1) + val matchPreferredHosts = exactMatchPreferredLocations.map(locs => locs.map(Utils.parseHostPort(_)._1)) + .reduce((x, y) => x.intersect(y)) + val otherNodeLocalLocations = matchPreferredHosts.filter { s => !exactMatchHosts.contains(s) } + + otherNodeLocalLocations ++ exactMatchLocations + } + + override def clearDependencies() { + super.clearDependencies() + rdds = null + } +} + +class ZippedPartitionsRDD2[A: ClassManifest, B: ClassManifest, V: ClassManifest]( + sc: SparkContext, + f: (Iterator[A], Iterator[B]) => Iterator[V], + var rdd1: RDD[A], + var rdd2: RDD[B]) + extends ZippedPartitionsBaseRDD[V](sc, List(rdd1, rdd2)) { + + override def compute(s: Partition, context: TaskContext): Iterator[V] = { + val partitions = s.asInstanceOf[ZippedPartitionsPartition].partitions + f(rdd1.iterator(partitions(0), context), rdd2.iterator(partitions(1), context)) + } + + override def clearDependencies() { + super.clearDependencies() + rdd1 = null + rdd2 = null + } +} + +class ZippedPartitionsRDD3 + [A: ClassManifest, B: ClassManifest, C: ClassManifest, V: ClassManifest]( + sc: SparkContext, + f: (Iterator[A], Iterator[B], Iterator[C]) => Iterator[V], + var rdd1: RDD[A], + var rdd2: RDD[B], + var rdd3: RDD[C]) + extends ZippedPartitionsBaseRDD[V](sc, List(rdd1, rdd2, rdd3)) { + + override def compute(s: Partition, context: TaskContext): 
Iterator[V] = { + val partitions = s.asInstanceOf[ZippedPartitionsPartition].partitions + f(rdd1.iterator(partitions(0), context), + rdd2.iterator(partitions(1), context), + rdd3.iterator(partitions(2), context)) + } + + override def clearDependencies() { + super.clearDependencies() + rdd1 = null + rdd2 = null + rdd3 = null + } +} + +class ZippedPartitionsRDD4 + [A: ClassManifest, B: ClassManifest, C: ClassManifest, D:ClassManifest, V: ClassManifest]( + sc: SparkContext, + f: (Iterator[A], Iterator[B], Iterator[C], Iterator[D]) => Iterator[V], + var rdd1: RDD[A], + var rdd2: RDD[B], + var rdd3: RDD[C], + var rdd4: RDD[D]) + extends ZippedPartitionsBaseRDD[V](sc, List(rdd1, rdd2, rdd3, rdd4)) { + + override def compute(s: Partition, context: TaskContext): Iterator[V] = { + val partitions = s.asInstanceOf[ZippedPartitionsPartition].partitions + f(rdd1.iterator(partitions(0), context), + rdd2.iterator(partitions(1), context), + rdd3.iterator(partitions(2), context), + rdd4.iterator(partitions(3), context)) + } + + override def clearDependencies() { + super.clearDependencies() + rdd1 = null + rdd2 = null + rdd3 = null + rdd4 = null + } +} diff --git a/core/src/main/scala/spark/rdd/ZippedRDD.scala b/core/src/main/scala/spark/rdd/ZippedRDD.scala index 1b438cd505..be05fb71f9 100644 --- a/core/src/main/scala/spark/rdd/ZippedRDD.scala +++ b/core/src/main/scala/spark/rdd/ZippedRDD.scala @@ -1,5 +1,7 @@ package spark.rdd +import spark.{Utils, OneToOneDependency, RDD, SparkContext, Partition, TaskContext} + import java.io.{ObjectOutputStream, IOException} import scala.reflect.ClassTag @@ -50,8 +52,27 @@ class ZippedRDD[T: ClassTag, U: ClassTag]( } override def getPreferredLocations(s: Partition): Seq[String] = { + // Note that as number of slaves in cluster increase, the computed preferredLocations can become small : so we might need + // to look at alternate strategies to alleviate this. 
(If there are no (or very small number of preferred locations), we + // will end up transferred the blocks to 'any' node in the cluster - paying with n/w and cache cost. + // Maybe pick one or the other ? (so that atleast one block is local ?). + // Choose node which is hosting 'larger' of the blocks ? + // Look at rack locality to ensure chosen host is atleast rack local to both hosting node ?, etc (would be good to defer this if possible) val (partition1, partition2) = s.asInstanceOf[ZippedPartition[T, U]].partitions - rdd1.preferredLocations(partition1).intersect(rdd2.preferredLocations(partition2)) + val pref1 = rdd1.preferredLocations(partition1) + val pref2 = rdd2.preferredLocations(partition2) + + // exact match - instance local and host local. + val exactMatchLocations = pref1.intersect(pref2) + + // remove locations which are already handled via exactMatchLocations, and intersect where both partitions are node local. + val otherNodeLocalPref1 = pref1.filter(loc => ! exactMatchLocations.contains(loc)).map(loc => Utils.parseHostPort(loc)._1) + val otherNodeLocalPref2 = pref2.filter(loc => ! exactMatchLocations.contains(loc)).map(loc => Utils.parseHostPort(loc)._1) + val otherNodeLocalLocations = otherNodeLocalPref1.intersect(otherNodeLocalPref2) + + + // Can have mix of instance local (hostPort) and node local (host) locations as preference ! + exactMatchLocations ++ otherNodeLocalLocations } override def clearDependencies() { diff --git a/core/src/main/scala/spark/scheduler/ActiveJob.scala b/core/src/main/scala/spark/scheduler/ActiveJob.scala index 5a4e9a582d..105eaecb22 100644 --- a/core/src/main/scala/spark/scheduler/ActiveJob.scala +++ b/core/src/main/scala/spark/scheduler/ActiveJob.scala @@ -2,6 +2,8 @@ package spark.scheduler import spark.TaskContext +import java.util.Properties + /** * Tracks information about an active job in the DAGScheduler. 
*/ @@ -11,7 +13,8 @@ private[spark] class ActiveJob( val func: (TaskContext, Iterator[_]) => _, val partitions: Array[Int], val callSite: String, - val listener: JobListener) { + val listener: JobListener, + val properties: Properties) { val numPartitions = partitions.length val finished = Array.fill[Boolean](numPartitions)(false) diff --git a/core/src/main/scala/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/spark/scheduler/DAGScheduler.scala index b838cf84a8..1164c40c43 100644 --- a/core/src/main/scala/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/spark/scheduler/DAGScheduler.scala @@ -4,6 +4,7 @@ import cluster.TaskInfo import java.util.concurrent.atomic.AtomicInteger import java.util.concurrent.LinkedBlockingQueue import java.util.concurrent.TimeUnit +import java.util.Properties import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, Map} import scala.reflect.ClassTag @@ -13,7 +14,7 @@ import spark.executor.TaskMetrics import spark.partial.ApproximateActionListener import spark.partial.ApproximateEvaluator import spark.partial.PartialResult -import spark.storage.BlockManagerMaster +import spark.storage.{BlockManager, BlockManagerMaster} import spark.util.{MetadataCleaner, TimeStampedHashMap} /** @@ -51,6 +52,11 @@ class DAGScheduler( eventQueue.put(ExecutorLost(execId)) } + // Called by TaskScheduler when a host is added + override def executorGained(execId: String, hostPort: String) { + eventQueue.put(ExecutorGained(execId, hostPort)) + } + // Called by TaskScheduler to cancel an entire TaskSet due to repeated failures. override def taskSetFailed(taskSet: TaskSet, reason: String) { eventQueue.put(TaskSetFailed(taskSet, reason)) @@ -89,6 +95,8 @@ class DAGScheduler( // stray messages to detect. 
val failedGeneration = new HashMap[String, Long] + val idToActiveJob = new HashMap[Int, ActiveJob] + val waiting = new HashSet[Stage] // Stages we need to run whose parents aren't done val running = new HashSet[Stage] // Stages we are running right now val failed = new HashSet[Stage] // Stages that must be resubmitted due to fetch failures @@ -113,9 +121,8 @@ class DAGScheduler( private def getCacheLocs(rdd: RDD[_]): Array[List[String]] = { if (!cacheLocs.contains(rdd.id)) { val blockIds = rdd.partitions.indices.map(index=> "rdd_%d_%d".format(rdd.id, index)).toArray - cacheLocs(rdd.id) = blockManagerMaster.getLocations(blockIds).map { - locations => locations.map(_.ip).toList - }.toArray + val locs = BlockManager.blockIdsToExecutorLocations(blockIds, env, blockManagerMaster) + cacheLocs(rdd.id) = blockIds.map(locs.getOrElse(_, Nil)) } cacheLocs(rdd.id) } @@ -222,13 +229,14 @@ class DAGScheduler( partitions: Seq[Int], callSite: String, allowLocal: Boolean, - resultHandler: (Int, U) => Unit) + resultHandler: (Int, U) => Unit, + properties: Properties = null) : (JobSubmitted, JobWaiter[U]) = { assert(partitions.size > 0) val waiter = new JobWaiter(partitions.size, resultHandler) val func2 = func.asInstanceOf[(TaskContext, Iterator[_]) => _] - val toSubmit = JobSubmitted(finalRdd, func2, partitions.toArray, allowLocal, callSite, waiter) + val toSubmit = JobSubmitted(finalRdd, func2, partitions.toArray, allowLocal, callSite, waiter, properties) return (toSubmit, waiter) } @@ -238,13 +246,14 @@ class DAGScheduler( partitions: Seq[Int], callSite: String, allowLocal: Boolean, - resultHandler: (Int, U) => Unit) + resultHandler: (Int, U) => Unit, + properties: Properties = null) { if (partitions.size == 0) { return } val (toSubmit, waiter) = prepareJob( - finalRdd, func, partitions, callSite, allowLocal, resultHandler) + finalRdd, func, partitions, callSite, allowLocal, resultHandler, properties) eventQueue.put(toSubmit) waiter.awaitResult() match { case JobSucceeded => {} 
@@ -259,13 +268,14 @@ class DAGScheduler( func: (TaskContext, Iterator[T]) => U, evaluator: ApproximateEvaluator[U, R], callSite: String, - timeout: Long) + timeout: Long, + properties: Properties = null) : PartialResult[R] = { val listener = new ApproximateActionListener(rdd, func, evaluator, timeout) val func2 = func.asInstanceOf[(TaskContext, Iterator[_]) => _] val partitions = (0 until rdd.partitions.size).toArray - eventQueue.put(JobSubmitted(rdd, func2, partitions, false, callSite, listener)) + eventQueue.put(JobSubmitted(rdd, func2, partitions, false, callSite, listener, properties)) return listener.awaitResult() // Will throw an exception if the job fails } @@ -275,10 +285,10 @@ class DAGScheduler( */ private[scheduler] def processEvent(event: DAGSchedulerEvent): Boolean = { event match { - case JobSubmitted(finalRDD, func, partitions, allowLocal, callSite, listener) => + case JobSubmitted(finalRDD, func, partitions, allowLocal, callSite, listener, properties) => val runId = nextRunId.getAndIncrement() val finalStage = newStage(finalRDD, None, runId) - val job = new ActiveJob(runId, finalStage, func, partitions, callSite, listener) + val job = new ActiveJob(runId, finalStage, func, partitions, callSite, listener, properties) clearCacheLocs() logInfo("Got job " + job.runId + " (" + callSite + ") with " + partitions.length + " output partitions (allowLocal=" + allowLocal + ")") @@ -289,15 +299,22 @@ class DAGScheduler( // Compute very short actions like first() or take() with no parent stages locally. 
runLocally(job) } else { + sparkListeners.foreach(_.onJobStart(SparkListenerJobStart(job, properties))) + idToActiveJob(runId) = job activeJobs += job resultStageToJob(finalStage) = job submitStage(finalStage) } + case ExecutorGained(execId, hostPort) => + handleExecutorGained(execId, hostPort) + case ExecutorLost(execId) => handleExecutorLost(execId) case completion: CompletionEvent => + sparkListeners.foreach(_.onTaskEnd(SparkListenerTaskEnd(completion.task, + completion.reason, completion.taskInfo, completion.taskMetrics))) handleTaskCompletion(completion) case TaskSetFailed(taskSet, reason) => @@ -308,6 +325,7 @@ class DAGScheduler( for (job <- activeJobs) { val error = new SparkException("Job cancelled because SparkContext was shut down") job.listener.jobFailed(error) + sparkListeners.foreach(_.onJobEnd(SparkListenerJobEnd(job, JobFailed(error)))) } return true } @@ -455,11 +473,13 @@ class DAGScheduler( } } if (tasks.size > 0) { + sparkListeners.foreach(_.onStageSubmitted(SparkListenerStageSubmitted(stage, tasks.size))) logInfo("Submitting " + tasks.size + " missing tasks from " + stage + " (" + stage.rdd + ")") myPending ++= tasks logDebug("New pending tasks: " + myPending) + val properties = idToActiveJob(stage.priority).properties taskSched.submitTasks( - new TaskSet(tasks.toArray, stage.id, stage.newAttemptId(), stage.priority)) + new TaskSet(tasks.toArray, stage.id, stage.newAttemptId(), stage.priority, properties)) if (!stage.submissionTime.isDefined) { stage.submissionTime = Some(System.currentTimeMillis()) } @@ -508,6 +528,7 @@ class DAGScheduler( activeJobs -= job resultStageToJob -= stage markStageAsFinished(stage) + sparkListeners.foreach(_.onJobEnd(SparkListenerJobEnd(job, JobSucceeded))) } job.listener.taskSucceeded(rt.outputId, event.result) } @@ -631,6 +652,14 @@ class DAGScheduler( "(generation " + currentGeneration + ")") } } + + private def handleExecutorGained(execId: String, hostPort: String) { + // remove from failedGeneration(execId) ? 
+ if (failedGeneration.contains(execId)) { + logInfo("Host gained which was in lost list earlier: " + hostPort) + failedGeneration -= execId + } + } /** * Aborts all jobs depending on a particular Stage. This is called in response to a task set @@ -640,7 +669,9 @@ class DAGScheduler( val dependentStages = resultStageToJob.keys.filter(x => stageDependsOn(x, failedStage)).toSeq for (resultStage <- dependentStages) { val job = resultStageToJob(resultStage) - job.listener.jobFailed(new SparkException("Job failed: " + reason)) + val error = new SparkException("Job failed: " + reason) + job.listener.jobFailed(error) + sparkListeners.foreach(_.onJobEnd(SparkListenerJobEnd(job, JobFailed(error)))) activeJobs -= job resultStageToJob -= resultStage } diff --git a/core/src/main/scala/spark/scheduler/DAGSchedulerEvent.scala b/core/src/main/scala/spark/scheduler/DAGSchedulerEvent.scala index ed0b9bf178..acad915f13 100644 --- a/core/src/main/scala/spark/scheduler/DAGSchedulerEvent.scala +++ b/core/src/main/scala/spark/scheduler/DAGSchedulerEvent.scala @@ -1,5 +1,7 @@ package spark.scheduler +import java.util.Properties + import spark.scheduler.cluster.TaskInfo import scala.collection.mutable.Map @@ -20,7 +22,8 @@ private[spark] case class JobSubmitted( partitions: Array[Int], allowLocal: Boolean, callSite: String, - listener: JobListener) + listener: JobListener, + properties: Properties = null) extends DAGSchedulerEvent private[spark] case class CompletionEvent( @@ -32,6 +35,10 @@ private[spark] case class CompletionEvent( taskMetrics: TaskMetrics) extends DAGSchedulerEvent +private[spark] case class ExecutorGained(execId: String, hostPort: String) extends DAGSchedulerEvent { + Utils.checkHostPort(hostPort, "Required hostport") +} + private[spark] case class ExecutorLost(execId: String) extends DAGSchedulerEvent private[spark] case class TaskSetFailed(taskSet: TaskSet, reason: String) extends DAGSchedulerEvent diff --git 
a/core/src/main/scala/spark/scheduler/InputFormatInfo.scala b/core/src/main/scala/spark/scheduler/InputFormatInfo.scala new file mode 100644 index 0000000000..287f731787 --- /dev/null +++ b/core/src/main/scala/spark/scheduler/InputFormatInfo.scala @@ -0,0 +1,156 @@ +package spark.scheduler + +import spark.Logging +import scala.collection.immutable.Set +import org.apache.hadoop.mapred.{FileInputFormat, JobConf} +import org.apache.hadoop.util.ReflectionUtils +import org.apache.hadoop.mapreduce.Job +import org.apache.hadoop.conf.Configuration +import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet} +import scala.collection.JavaConversions._ + + +/** + * Parses and holds information about inputFormat (and files) specified as a parameter. + */ +class InputFormatInfo(val configuration: Configuration, val inputFormatClazz: Class[_], + val path: String) extends Logging { + + var mapreduceInputFormat: Boolean = false + var mapredInputFormat: Boolean = false + + validate() + + override def toString(): String = { + "InputFormatInfo " + super.toString + " .. inputFormatClazz " + inputFormatClazz + ", path : " + path + } + + override def hashCode(): Int = { + var hashCode = inputFormatClazz.hashCode + hashCode = hashCode * 31 + path.hashCode + hashCode + } + + // Since we are not doing canonicalization of path, this can be wrong : like relative vs absolute path + // .. which is fine, this is best case effort to remove duplicates - right ? + override def equals(other: Any): Boolean = other match { + case that: InputFormatInfo => { + // not checking config - that should be fine, right ? 
+ this.inputFormatClazz == that.inputFormatClazz && + this.path == that.path + } + case _ => false + } + + private def validate() { + logDebug("validate InputFormatInfo : " + inputFormatClazz + ", path " + path) + + try { + if (classOf[org.apache.hadoop.mapreduce.InputFormat[_, _]].isAssignableFrom(inputFormatClazz)) { + logDebug("inputformat is from mapreduce package") + mapreduceInputFormat = true + } + else if (classOf[org.apache.hadoop.mapred.InputFormat[_, _]].isAssignableFrom(inputFormatClazz)) { + logDebug("inputformat is from mapred package") + mapredInputFormat = true + } + else { + throw new IllegalArgumentException("Specified inputformat " + inputFormatClazz + + " is NOT a supported input format ? does not implement either of the supported hadoop api's") + } + } + catch { + case e: ClassNotFoundException => { + throw new IllegalArgumentException("Specified inputformat " + inputFormatClazz + " cannot be found ?", e) + } + } + } + + + // This method does not expect failures, since validate has already passed ... + private def prefLocsFromMapreduceInputFormat(): Set[SplitInfo] = { + val conf = new JobConf(configuration) + FileInputFormat.setInputPaths(conf, path) + + val instance: org.apache.hadoop.mapreduce.InputFormat[_, _] = + ReflectionUtils.newInstance(inputFormatClazz.asInstanceOf[Class[_]], conf).asInstanceOf[ + org.apache.hadoop.mapreduce.InputFormat[_, _]] + val job = new Job(conf) + + val retval = new ArrayBuffer[SplitInfo]() + val list = instance.getSplits(job) + for (split <- list) { + retval ++= SplitInfo.toSplitInfo(inputFormatClazz, path, split) + } + + return retval.toSet + } + + // This method does not expect failures, since validate has already passed ... 
+ private def prefLocsFromMapredInputFormat(): Set[SplitInfo] = { + val jobConf = new JobConf(configuration) + FileInputFormat.setInputPaths(jobConf, path) + + val instance: org.apache.hadoop.mapred.InputFormat[_, _] = + ReflectionUtils.newInstance(inputFormatClazz.asInstanceOf[Class[_]], jobConf).asInstanceOf[ + org.apache.hadoop.mapred.InputFormat[_, _]] + + val retval = new ArrayBuffer[SplitInfo]() + instance.getSplits(jobConf, jobConf.getNumMapTasks()).foreach( + elem => retval ++= SplitInfo.toSplitInfo(inputFormatClazz, path, elem) + ) + + return retval.toSet + } + + private def findPreferredLocations(): Set[SplitInfo] = { + logDebug("mapreduceInputFormat : " + mapreduceInputFormat + ", mapredInputFormat : " + mapredInputFormat + + ", inputFormatClazz : " + inputFormatClazz) + if (mapreduceInputFormat) { + return prefLocsFromMapreduceInputFormat() + } + else { + assert(mapredInputFormat) + return prefLocsFromMapredInputFormat() + } + } +} + + + + +object InputFormatInfo { + /** + Computes the preferred locations based on input(s) and returned a location to block map. + Typical use of this method for allocation would follow some algo like this + (which is what we currently do in YARN branch) : + a) For each host, count number of splits hosted on that host. + b) Decrement the currently allocated containers on that host. + c) Compute rack info for each host and update rack -> count map based on (b). + d) Allocate nodes based on (c) + e) On the allocation result, ensure that we dont allocate "too many" jobs on a single node + (even if data locality on that is very high) : this is to prevent fragility of job if a single + (or small set of) hosts go down. + + go to (a) until required nodes are allocated. + + If a node 'dies', follow same procedure. + + PS: I know the wording here is weird, hopefully it makes some sense ! 
+ */ + def computePreferredLocations(formats: Seq[InputFormatInfo]): HashMap[String, HashSet[SplitInfo]] = { + + val nodeToSplit = new HashMap[String, HashSet[SplitInfo]] + for (inputSplit <- formats) { + val splits = inputSplit.findPreferredLocations() + + for (split <- splits){ + val location = split.hostLocation + val set = nodeToSplit.getOrElseUpdate(location, new HashSet[SplitInfo]) + set += split + } + } + + nodeToSplit + } +} diff --git a/core/src/main/scala/spark/scheduler/JobLogger.scala b/core/src/main/scala/spark/scheduler/JobLogger.scala new file mode 100644 index 0000000000..178bfaba3d --- /dev/null +++ b/core/src/main/scala/spark/scheduler/JobLogger.scala @@ -0,0 +1,306 @@ +package spark.scheduler
+
+import java.io.PrintWriter
+import java.io.File
+import java.io.FileNotFoundException
+import java.text.SimpleDateFormat
+import java.util.{Date, Properties}
+import java.util.concurrent.LinkedBlockingQueue
+import scala.collection.mutable.{Map, HashMap, ListBuffer}
+import scala.io.Source
+import spark._
+import spark.executor.TaskMetrics
+import spark.scheduler.cluster.TaskInfo
+
+// Records runtime information for each job into per-job log files: the RDD graph,
+// the tasks' start/stop and shuffle information, and information from outside
+
+class JobLogger(val logDirName: String) extends SparkListener with Logging {
+ // Root directory for job logs: $SPARK_LOG_DIR when it is set, otherwise /tmp/spark.
+ private val logDir = Option(System.getenv("SPARK_LOG_DIR")).getOrElse("/tmp/spark")
+ // Open log writer for each job id.
+ private val jobIDToPrintWriter = new HashMap[Int, PrintWriter]
+ // Job id that each registered stage id belongs to.
+ private val stageIDToJobID = new HashMap[Int, Int]
+ // Stages registered for each job id.
+ private val jobIDToStages = new HashMap[Int, ListBuffer[Stage]]
+ // Timestamp format used as the prefix of timed log lines.
+ private val DATE_FORMAT = new SimpleDateFormat("yyyy/MM/dd HH:mm:ss")
+ // Events waiting to be handled by the "JobLogger" thread.
+ private val eventQueue = new LinkedBlockingQueue[SparkListenerEvents]
+
+ // Runs as part of construction, after logDir has been initialised.
+ createLogDir()
+
+ // Default log directory name: the creation time of this logger in milliseconds.
+ def this() = this(String.valueOf(System.currentTimeMillis()))
+
+ // Read accessors for the internal state.
+ def getLogDir: String = logDir
+ def getJobIDtoPrintWriter: HashMap[Int, PrintWriter] = jobIDToPrintWriter
+ def getStageIDToJobID: HashMap[Int, Int] = stageIDToJobID
+ def getJobIDToStages: HashMap[Int, ListBuffer[Stage]] = jobIDToStages
+ def getEventQueue: LinkedBlockingQueue[SparkListenerEvents] = eventQueue
+
+ // Background daemon thread, started while the JobLogger is being constructed.
+ // It loops forever: blocks on eventQueue.take, logs the event's class name at debug
+ // level, and hands the event to the matching process*Event handler.
+ // NOTE(review): the code that puts events on eventQueue is not in this part of the
+ // file - presumably the SparkListener callbacks; confirm.
+ new Thread("JobLogger") {
+ setDaemon(true)
+ override def run() {
+ while (true) {
+ val event = eventQueue.take
+ logDebug("Got event of type " + event.getClass.getName)
+ event match {
+ case SparkListenerJobStart(job, properties) =>
+ processJobStartEvent(job, properties)
+ case SparkListenerStageSubmitted(stage, taskSize) =>
+ processStageSubmittedEvent(stage, taskSize)
+ case StageCompleted(stageInfo) =>
+ processStageCompletedEvent(stageInfo)
+ case SparkListenerJobEnd(job, result) =>
+ processJobEndEvent(job, result)
+ case SparkListenerTaskEnd(task, reason, taskInfo, taskMetrics) =>
+ processTaskEndEvent(task, reason, taskInfo, taskMetrics)
+ // Any other event type is ignored.
+ case _ =>
+ }
+ }
+ }
+ }.start()
+
+ // Makes sure the directory <logDir>/<logDirName>/ exists. With the no-arg constructor
+ // logDirName is the creation time of the JobLogger. A failure to create the directory
+ // is logged as an error and not thrown.
+ protected def createLogDir() {
+   val target = new File(logDir + "/" + logDirName + "/")
+   if (!target.exists() && !target.mkdirs()) {
+     logError("create log directory error:" + logDir + "/" + logDirName + "/")
+   }
+ }
+
+ // Opens the log file <logDir>/<logDirName>/<jobID> and registers its writer under jobID.
+ // If the file cannot be opened, the failure is logged and no writer is registered, so
+ // later jobLogInfo calls for this job write nothing.
+ protected def createLogWriter(jobID: Int) {
+   val logFile = logDir + "/" + logDirName + "/" + jobID
+   try {
+     jobIDToPrintWriter += (jobID -> new PrintWriter(logFile))
+   } catch {
+     // Report through the Logging trait, as createLogDir does, rather than dumping the
+     // stack trace to stderr.
+     case e: FileNotFoundException => logError("create log file error:" + logFile, e)
+   }
+ }
+
+ // Closes the log file of job jobID and drops everything recorded for that job: its
+ // stage ids in stageIDToJobID, its writer and its stage list. Does nothing when no
+ // writer is registered for jobID.
+ protected def closeLogWriter(jobID: Int) =
+   for (writer <- jobIDToPrintWriter.get(jobID)) {
+     writer.close()
+     for (stages <- jobIDToStages.get(jobID); stage <- stages) {
+       stageIDToJobID -= stage.id
+     }
+     jobIDToPrintWriter -= jobID
+     jobIDToStages -= jobID
+   }
+
+ // Writes one line of information to the log file of job jobID. When withTime is true
+ // the line is prefixed with the current time stamp. Nothing is written if the job has
+ // no registered writer.
+ protected def jobLogInfo(jobID: Int, info: String, withTime: Boolean = true) {
+   val line =
+     if (withTime) DATE_FORMAT.format(new Date(System.currentTimeMillis())) + ": " + info
+     else info
+   jobIDToPrintWriter.get(jobID).foreach(writer => writer.println(line))
+ }
+
+ // Writes info to the log of the job that stage stageID is registered under; does
+ // nothing when the stage is not registered in stageIDToJobID.
+ protected def stageLogInfo(stageID: Int, info: String, withTime: Boolean = true) =
+   for (jobID <- stageIDToJobID.get(stageID)) {
+     jobLogInfo(jobID, info, withTime)
+   }
+
+ // Registers stage, and recursively its parent stages, as belonging to job jobID.
+ // A stage whose priority differs from jobID is skipped and its parents are not visited.
+ protected def buildJobDep(jobID: Int, stage: Stage) {
+   if (stage.priority == jobID) {
+     // Creates the job's stage list on first use, then appends to it.
+     jobIDToStages.getOrElseUpdate(jobID, new ListBuffer[Stage]) += stage
+     stageIDToJobID += (stage.id -> jobID)
+     for (parent <- stage.parents) {
+       buildJobDep(jobID, parent)
+     }
+   }
+ }
+
+ // Writes one untimed line per stage registered for job jobID, of the form
+ // STAGE_ID=<id> RDD_DEP=(<rdd ids>) STAGE_DEP=(<parent stage id>,<shuffle id>)...
+ protected def recordStageDep(jobID: Int) {
+   // Collects rdd plus every RDD reachable from it without crossing a ShuffleDependency.
+   def getRddsInStage(rdd: RDD[_]): ListBuffer[RDD[_]] = {
+     val collected = new ListBuffer[RDD[_]]
+     collected += rdd
+     rdd.dependencies.foreach {
+       case _: ShuffleDependency[_, _] => // shuffle dependency: do not descend
+       case narrow => collected ++= getRddsInStage(narrow.rdd)
+     }
+     collected
+   }
+   for (stages <- jobIDToStages.get(jobID); stage <- stages) {
+     // Comma-separated RDD ids; the list always holds at least stage.rdd.
+     val rddIds = getRddsInStage(stage.rdd).map(_.id).mkString(",")
+     val parentDesc = stage.parents
+       .map(parent => "(" + parent.id + "," + parent.shuffleDep.get.shuffleId + ")")
+       .mkString
+     jobLogInfo(jobID, "STAGE_ID=" + stage.id + " RDD_DEP=(" + rddIds + ")" +
+       " STAGE_DEP=" + parentDesc, false)
+   }
+ }
+
+ // Builds the indentation prefix: one " " per indent level, empty when indent <= 0.
+ protected def indentString(indent: Int) = {
+   " " * indent
+ }
+
+ // Display name of an RDD: its name when one is set, otherwise its class name.
+ protected def getRddName(rdd: RDD[_]) = {
+   Option(rdd.name).getOrElse(rdd.getClass.getName)
+ }
+
+ // Logs `rdd` at the given indentation, then walks its dependencies one level deeper:
+ // a shuffle dependency is logged as SHUFFLE_ID=<id> and not followed, any other
+ // dependency is followed recursively. All lines are written without a timestamp.
+ protected def recordRddInStageGraph(jobID: Int, rdd: RDD[_], indent: Int) {
+   val description = "RDD_ID=" + rdd.id + "(" + getRddName(rdd) + "," + rdd.generator + ")"
+   jobLogInfo(jobID, indentString(indent) + description, false)
+   val childIndent = indent + 1
+   rdd.dependencies.foreach {
+     case shuffle: ShuffleDependency[_, _] =>
+       jobLogInfo(jobID, indentString(childIndent) + "SHUFFLE_ID=" + shuffle.shuffleId, false)
+     case narrow =>
+       recordRddInStageGraph(jobID, narrow.rdd, childIndent)
+   }
+ }
+
+ // Logs the stage graph rooted at `stage`. A stage whose priority equals `jobID` is
+ // logged together with its RDD graph and its parents, each parent two indent levels
+ // deeper. Any other stage gets a single line tagged with JOB_ID=<its priority>.
+ protected def recordStageDepGraph(jobID: Int, stage: Stage, indent: Int = 0) {
+   val stageInfo =
+     if (stage.isShuffleMap) {
+       "STAGE_ID=" + stage.id + " MAP_STAGE SHUFFLE_ID=" + stage.shuffleDep.get.shuffleId
+     } else {
+       "STAGE_ID=" + stage.id + " RESULT_STAGE"
+     }
+   val prefix = indentString(indent)
+   if (stage.priority != jobID) {
+     jobLogInfo(jobID, prefix + stageInfo + " JOB_ID=" + stage.priority, false)
+   } else {
+     jobLogInfo(jobID, prefix + stageInfo, false)
+     recordRddInStageGraph(jobID, stage.rdd, indent)
+     for (parent <- stage.parents) {
+       recordStageDepGraph(jobID, parent, indent + 2)
+     }
+   }
+ }
+
+ // Record task metrics into job log files.
+ // The line is `status` followed by the task identity and timing fields, the executor
+ // run time, and the shuffle read / shuffle write fields when those metrics are present.
+ protected def recordTaskMetrics(stageID: Int, status: String,
+     taskInfo: TaskInfo, taskMetrics: TaskMetrics) {
+   val info = " TID=" + taskInfo.taskId + " STAGE_ID=" + stageID +
+     " START_TIME=" + taskInfo.launchTime + " FINISH_TIME=" + taskInfo.finishTime +
+     " EXECUTOR_ID=" + taskInfo.executorId + " HOST=" + taskMetrics.hostname
+   val executorRunTime = " EXECUTOR_RUN_TIME=" + taskMetrics.executorRunTime
+   // Absent shuffle metrics contribute nothing to the line.
+   val readMetrics = taskMetrics.shuffleReadMetrics.map { read =>
+     " SHUFFLE_FINISH_TIME=" + read.shuffleFinishTime +
+     " BLOCK_FETCHED_TOTAL=" + read.totalBlocksFetched +
+     " BLOCK_FETCHED_LOCAL=" + read.localBlocksFetched +
+     " BLOCK_FETCHED_REMOTE=" + read.remoteBlocksFetched +
+     " REMOTE_FETCH_WAIT_TIME=" + read.fetchWaitTime +
+     " REMOTE_FETCH_TIME=" + read.remoteFetchTime +
+     " REMOTE_BYTES_READ=" + read.remoteBytesRead
+   }.getOrElse("")
+   val writeMetrics = taskMetrics.shuffleWriteMetrics.map { write =>
+     " SHUFFLE_BYTES_WRITTEN=" + write.shuffleBytesWritten
+   }.getOrElse("")
+   stageLogInfo(stageID, status + info + executorRunTime + readMetrics + writeMetrics)
+ }
+
+ // Puts the stage-submitted event on eventQueue.
+ override def onStageSubmitted(stageSubmitted: SparkListenerStageSubmitted): Unit =
+   eventQueue.put(stageSubmitted)
+
+ // Logs a STATUS=SUBMITTED line with the stage id and TASK_SIZE for `stage`.
+ protected def processStageSubmittedEvent(stage: Stage, taskSize: Int) {
+   val stageId = stage.id
+   val message = "STAGE_ID=" + stageId + " STATUS=SUBMITTED" + " TASK_SIZE=" + taskSize
+   stageLogInfo(stageId, message)
+ }
+
+ // Puts the stage-completed event on eventQueue.
+ override def onStageCompleted(stageCompleted: StageCompleted): Unit =
+   eventQueue.put(stageCompleted)
+
+ // Logs a STATUS=COMPLETED line for the stage referenced by `stageInfo`.
+ protected def processStageCompletedEvent(stageInfo: StageInfo) {
+   val stageId = stageInfo.stage.id
+   stageLogInfo(stageId, "STAGE_ID=" + stageId + " STATUS=COMPLETED")
+ }
+
+ // Puts the task-end event on eventQueue.
+ override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit =
+   eventQueue.put(taskEnd)
+
+ // Logs the outcome of a finished task to the log of the job owning the task's stage.
+ // A successful task is logged with its metrics through recordTaskMetrics; resubmitted,
+ // fetch-failed and other-failure outcomes get a status line with the task and stage
+ // ids; any other end reason is not logged.
+ protected def processTaskEndEvent(task: Task[_], reason: TaskEndReason,
+     taskInfo: TaskInfo, taskMetrics: TaskMetrics) {
+   val taskType = task match {
+     case _: ResultTask[_, _] => "TASK_TYPE=RESULT_TASK"
+     case _: ShuffleMapTask => "TASK_TYPE=SHUFFLE_MAP_TASK"
+   }
+   val stageId = task.stageId
+   // Suffix shared by the non-success lines; a def so that taskInfo is only read by
+   // the branches that actually log it.
+   def taskIdent = " TID=" + taskInfo.taskId + " STAGE_ID=" + stageId
+   reason match {
+     case Success =>
+       recordTaskMetrics(stageId, taskType + " STATUS=SUCCESS", taskInfo, taskMetrics)
+     case Resubmitted =>
+       stageLogInfo(stageId, taskType + " STATUS=RESUBMITTED" + taskIdent)
+     case FetchFailed(_, shuffleId, mapId, reduceId) =>
+       stageLogInfo(stageId, taskType + " STATUS=FETCHFAILED" + taskIdent +
+         " SHUFFLE_ID=" + shuffleId + " MAP_ID=" + mapId + " REDUCE_ID=" + reduceId)
+     case OtherFailure(message) =>
+       stageLogInfo(stageId, taskType + " STATUS=FAILURE" + taskIdent + " INFO=" + message)
+     case _ =>
+   }
+ }
+
+ // Puts the job-end event on eventQueue.
+ override def onJobEnd(jobEnd: SparkListenerJobEnd): Unit =
+   eventQueue.put(jobEnd)
+
+ // Writes the final, upper-cased status line for `job` and then calls closeLogWriter
+ // for it. The line has one of the forms
+ //   JOB_ID=<id> STATUS=SUCCESS
+ //   JOB_ID=<id> STATUS=FAILED REASON=<message, whitespace runs replaced by "_">
+ //   JOB_ID=<id>                      (any other result)
+ protected def processJobEndEvent(job: ActiveJob, reason: JobResult) {
+   val status = reason match {
+     case JobSucceeded => " STATUS=SUCCESS"
+     case JobFailed(exception) =>
+       // getMessage may be null; fall back to toString so split() is never called
+       // on null and the log still names the failure.
+       val message = Option(exception.getMessage).getOrElse(exception.toString)
+       // mkString places "_" only between words, so there is no trailing separator
+       // to strip. Stripping the last character unconditionally used to truncate the
+       // success line to "STATUS=SUCCES" and to cut the job id in the fall-through case.
+       " STATUS=FAILED REASON=" + message.split("\\s+").mkString("_")
+     case _ => ""
+   }
+   jobLogInfo(job.runId, ("JOB_ID=" + job.runId + status).toUpperCase)
+   closeLogWriter(job.runId)
+ }
+
+ // Logs the value of the "spark.job.annotation" property (an empty line when the
+ // property is unset) without a timestamp. Nothing is logged when `properties` is null.
+ protected def recordJobProperties(jobID: Int, properties: Properties) {
+   for (props <- Option(properties)) {
+     jobLogInfo(jobID, props.getProperty("spark.job.annotation", ""), false)
+   }
+ }
+
+ // Puts the job-start event on eventQueue.
+ override def onJobStart(jobStart: SparkListenerJobStart): Unit =
+   eventQueue.put(jobStart)
+
+ // Handles a job-start event, in this order: calls createLogWriter for the job, logs
+ // the job annotation, registers the job's stages starting from its final stage, logs
+ // the per-stage dependencies and the stage graph, then logs the STATUS=STARTED line.
+ protected def processJobStartEvent(job: ActiveJob, properties: Properties) {
+   val jobID = job.runId
+   createLogWriter(jobID)
+   recordJobProperties(jobID, properties)
+   // The stage maps must be filled before the two recording steps read them.
+   buildJobDep(jobID, job.finalStage)
+   recordStageDep(jobID)
+   recordStageDepGraph(jobID, job.finalStage)
+   jobLogInfo(jobID, "JOB_ID=" + jobID + " STATUS=STARTED")
+ }
+}
diff --git a/core/src/main/scala/spark/scheduler/ResultTask.scala b/core/src/main/scala/spark/scheduler/ResultTask.scala index beb21a76fe..83166bce22 100644 --- a/core/src/main/scala/spark/scheduler/ResultTask.scala +++ b/core/src/main/scala/spark/scheduler/ResultTask.scala @@ -70,6 +70,13 @@ private[spark] class ResultTask[T, U]( rdd.partitions(partition) } + private val preferredLocs: Seq[String] = if (locs == null) Nil else locs.toSet.toSeq + + { + // DEBUG code + preferredLocs.foreach (hostPort => Utils.checkHost(Utils.parseHostPort(hostPort)._1, "preferredLocs : " + preferredLocs)) + } + override def run(attemptId: Long): U = { val context = new TaskContext(stageId, partition, attemptId) metrics = Some(context.taskMetrics) @@ -80,7 +87,7 @@ private[spark] class ResultTask[T, U]( } } - override def preferredLocations: Seq[String] = locs + override def preferredLocations: Seq[String] = preferredLocs override def toString = "ResultTask(" + stageId + ", " + partition + ")" diff --git a/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala b/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala index 36d087a4d0..95647389c3 100644 --- a/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala +++ b/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala @@ -13,9 +13,10 @@ import com.ning.compress.lzf.LZFInputStream import com.ning.compress.lzf.LZFOutputStream import spark._ -import executor.ShuffleWriteMetrics +import spark.executor.ShuffleWriteMetrics import spark.storage._ -import util.{TimeStampedHashMap, MetadataCleaner} +import spark.util.{TimeStampedHashMap, MetadataCleaner} + private[spark] object ShuffleMapTask { @@ -77,13 +78,20 @@ private[spark] class ShuffleMapTask( var rdd: RDD[_], var dep: ShuffleDependency[_,_], var partition: Int, - @transient var locs: Seq[String]) + @transient private var locs: Seq[String]) extends Task[MapStatus](stageId) with Externalizable with Logging { protected def this() = this(0, null, null, 0, null) + @transient 
private val preferredLocs: Seq[String] = if (locs == null) Nil else locs.toSet.toSeq + + { + // DEBUG code + preferredLocs.foreach (hostPort => Utils.checkHost(Utils.parseHostPort(hostPort)._1, "preferredLocs : " + preferredLocs)) + } + var split = if (rdd == null) { null } else { @@ -121,40 +129,58 @@ private[spark] class ShuffleMapTask( val taskContext = new TaskContext(stageId, partition, attemptId) metrics = Some(taskContext.taskMetrics) + + val blockManager = SparkEnv.get.blockManager + var shuffle: ShuffleBlocks = null + var buckets: ShuffleWriterGroup = null + try { - // Partition the map output. - val buckets = Array.fill(numOutputSplits)(new ArrayBuffer[(Any, Any)]) + // Obtain all the block writers for shuffle blocks. + val ser = SparkEnv.get.serializerManager.get(dep.serializerClass) + shuffle = blockManager.shuffleBlockManager.forShuffle(dep.shuffleId, numOutputSplits, ser) + buckets = shuffle.acquireWriters(partition) + + // Write the map output to its associated buckets. for (elem <- rdd.iterator(split, taskContext)) { val pair = elem.asInstanceOf[(Any, Any)] val bucketId = dep.partitioner.getPartition(pair._1) - buckets(bucketId) += pair + buckets.writers(bucketId).write(pair) } - val compressedSizes = new Array[Byte](numOutputSplits) - - var totalBytes = 0l - - val blockManager = SparkEnv.get.blockManager - for (i <- 0 until numOutputSplits) { - val blockId = "shuffle_" + dep.shuffleId + "_" + partition + "_" + i - // Get a Scala iterator from Java map - val iter: Iterator[(Any, Any)] = buckets(i).iterator - val size = blockManager.put(blockId, iter, StorageLevel.DISK_ONLY, false) + // Commit the writes. Get the size of each bucket block (total block size). 
+ var totalBytes = 0L + val compressedSizes: Array[Byte] = buckets.writers.map { writer: BlockObjectWriter => + writer.commit() + writer.close() + val size = writer.size() totalBytes += size - compressedSizes(i) = MapOutputTracker.compressSize(size) + MapOutputTracker.compressSize(size) } + + // Update shuffle metrics. val shuffleMetrics = new ShuffleWriteMetrics shuffleMetrics.shuffleBytesWritten = totalBytes metrics.get.shuffleWriteMetrics = Some(shuffleMetrics) return new MapStatus(blockManager.blockManagerId, compressedSizes) + } catch { case e: Exception => + // If there is an exception from running the task, revert the partial writes + // and throw the exception upstream to Spark. + if (buckets != null) { + buckets.writers.foreach(_.revertPartialWrites()) + } + throw e } finally { + // Release the writers back to the shuffle block manager. + if (shuffle != null && buckets != null) { + shuffle.releaseWriters(buckets) + } // Execute the callbacks on task completion. taskContext.executeOnCompleteCallbacks() } } - override def preferredLocations: Seq[String] = locs + override def preferredLocations: Seq[String] = preferredLocs override def toString = "ShuffleMapTask(%d, %d)".format(stageId, partition) } diff --git a/core/src/main/scala/spark/scheduler/SparkListener.scala b/core/src/main/scala/spark/scheduler/SparkListener.scala index a65140b145..bac984b5c9 100644 --- a/core/src/main/scala/spark/scheduler/SparkListener.scala +++ b/core/src/main/scala/spark/scheduler/SparkListener.scala @@ -1,27 +1,59 @@ package spark.scheduler +import java.util.Properties import spark.scheduler.cluster.TaskInfo import spark.util.Distribution -import spark.{Utils, Logging} +import spark.{Logging, SparkContext, TaskEndReason, Utils} import spark.executor.TaskMetrics -trait SparkListener { - /** - * called when a stage is completed, with information on the completed stage - */ - def onStageCompleted(stageCompleted: StageCompleted) -} - sealed trait SparkListenerEvents +case class 
SparkListenerStageSubmitted(stage: Stage, taskSize: Int) extends SparkListenerEvents + case class StageCompleted(val stageInfo: StageInfo) extends SparkListenerEvents +case class SparkListenerTaskEnd(task: Task[_], reason: TaskEndReason, taskInfo: TaskInfo, + taskMetrics: TaskMetrics) extends SparkListenerEvents + +case class SparkListenerJobStart(job: ActiveJob, properties: Properties = null) + extends SparkListenerEvents + +case class SparkListenerJobEnd(job: ActiveJob, jobResult: JobResult) + extends SparkListenerEvents + +trait SparkListener { + /** + * Called when a stage is completed, with information on the completed stage + */ + def onStageCompleted(stageCompleted: StageCompleted) { } + + /** + * Called when a stage is submitted + */ + def onStageSubmitted(stageSubmitted: SparkListenerStageSubmitted) { } + + /** + * Called when a task ends + */ + def onTaskEnd(taskEnd: SparkListenerTaskEnd) { } + + /** + * Called when a job starts + */ + def onJobStart(jobStart: SparkListenerJobStart) { } + + /** + * Called when a job ends + */ + def onJobEnd(jobEnd: SparkListenerJobEnd) { } + +} /** * Simple SparkListener that logs a few summary statistics when each stage completes */ class StatsReportListener extends SparkListener with Logging { - def onStageCompleted(stageCompleted: StageCompleted) { + override def onStageCompleted(stageCompleted: StageCompleted) { import spark.scheduler.StatsReportListener._ implicit val sc = stageCompleted this.logInfo("Finished stage: " + stageCompleted.stageInfo) diff --git a/core/src/main/scala/spark/scheduler/SplitInfo.scala b/core/src/main/scala/spark/scheduler/SplitInfo.scala new file mode 100644 index 0000000000..6abfb7a1f7 --- /dev/null +++ b/core/src/main/scala/spark/scheduler/SplitInfo.scala @@ -0,0 +1,61 @@ +package spark.scheduler + +import collection.mutable.ArrayBuffer + +// information about a specific split instance : handles both split instances. +// So that we do not need to worry about the differences. 
+class SplitInfo(val inputFormatClazz: Class[_], val hostLocation: String, val path: String, + val length: Long, val underlyingSplit: Any) { + override def toString(): String = { + "SplitInfo " + super.toString + " .. inputFormatClazz " + inputFormatClazz + + ", hostLocation : " + hostLocation + ", path : " + path + + ", length : " + length + ", underlyingSplit " + underlyingSplit + } + + override def hashCode(): Int = { + var hashCode = inputFormatClazz.hashCode + hashCode = hashCode * 31 + hostLocation.hashCode + hashCode = hashCode * 31 + path.hashCode + // ignore overflow ? It is hashcode anyway ! + hashCode = hashCode * 31 + (length & 0x7fffffff).toInt + hashCode + } + + // This is practically useless since most of the Split impl's dont seem to implement equals :-( + // So unless there is identity equality between underlyingSplits, it will always fail even if it + // is pointing to same block. + override def equals(other: Any): Boolean = other match { + case that: SplitInfo => { + this.hostLocation == that.hostLocation && + this.inputFormatClazz == that.inputFormatClazz && + this.path == that.path && + this.length == that.length && + // other split specific checks (like start for FileSplit) + this.underlyingSplit == that.underlyingSplit + } + case _ => false + } +} + +object SplitInfo { + + def toSplitInfo(inputFormatClazz: Class[_], path: String, + mapredSplit: org.apache.hadoop.mapred.InputSplit): Seq[SplitInfo] = { + val retval = new ArrayBuffer[SplitInfo]() + val length = mapredSplit.getLength + for (host <- mapredSplit.getLocations) { + retval += new SplitInfo(inputFormatClazz, host, path, length, mapredSplit) + } + retval + } + + def toSplitInfo(inputFormatClazz: Class[_], path: String, + mapreduceSplit: org.apache.hadoop.mapreduce.InputSplit): Seq[SplitInfo] = { + val retval = new ArrayBuffer[SplitInfo]() + val length = mapreduceSplit.getLength + for (host <- mapreduceSplit.getLocations) { + retval += new SplitInfo(inputFormatClazz, host, path, length, 
mapreduceSplit) + } + retval + } +} diff --git a/core/src/main/scala/spark/scheduler/Stage.scala b/core/src/main/scala/spark/scheduler/Stage.scala index 552061e46b..7fc9e13fd9 100644 --- a/core/src/main/scala/spark/scheduler/Stage.scala +++ b/core/src/main/scala/spark/scheduler/Stage.scala @@ -26,7 +26,7 @@ private[spark] class Stage( val parents: List[Stage], val priority: Int) extends Logging { - + val isShuffleMap = shuffleDep != None val numPartitions = rdd.partitions.size val outputLocs = Array.fill[List[MapStatus]](numPartitions)(Nil) @@ -60,7 +60,7 @@ private[spark] class Stage( numAvailableOutputs -= 1 } } - + def removeOutputsOnExecutor(execId: String) { var becameUnavailable = false for (partition <- 0 until numPartitions) { diff --git a/core/src/main/scala/spark/scheduler/TaskScheduler.scala b/core/src/main/scala/spark/scheduler/TaskScheduler.scala index d549b184b0..7787b54762 100644 --- a/core/src/main/scala/spark/scheduler/TaskScheduler.scala +++ b/core/src/main/scala/spark/scheduler/TaskScheduler.scala @@ -10,6 +10,10 @@ package spark.scheduler private[spark] trait TaskScheduler { def start(): Unit + // Invoked after system has successfully initialized (typically in spark context). + // Yarn uses this to bootstrap allocation of resources based on preferred locations, wait for slave registerations, etc. + def postStartHook() { } + // Disconnect from the cluster. def stop(): Unit diff --git a/core/src/main/scala/spark/scheduler/TaskSchedulerListener.scala b/core/src/main/scala/spark/scheduler/TaskSchedulerListener.scala index 771518dddf..b75d3736cf 100644 --- a/core/src/main/scala/spark/scheduler/TaskSchedulerListener.scala +++ b/core/src/main/scala/spark/scheduler/TaskSchedulerListener.scala @@ -14,6 +14,9 @@ private[spark] trait TaskSchedulerListener { def taskEnded(task: Task[_], reason: TaskEndReason, result: Any, accumUpdates: Map[Long, Any], taskInfo: TaskInfo, taskMetrics: TaskMetrics): Unit + // A node was added to the cluster. 
+ def executorGained(execId: String, hostPort: String): Unit + // A node was lost from the cluster. def executorLost(execId: String): Unit diff --git a/core/src/main/scala/spark/scheduler/TaskSet.scala b/core/src/main/scala/spark/scheduler/TaskSet.scala index a3002ca477..e4b5fcaedb 100644 --- a/core/src/main/scala/spark/scheduler/TaskSet.scala +++ b/core/src/main/scala/spark/scheduler/TaskSet.scala @@ -1,11 +1,18 @@ package spark.scheduler +import java.util.Properties + /** * A set of tasks submitted together to the low-level TaskScheduler, usually representing * missing partitions of a particular stage. */ -private[spark] class TaskSet(val tasks: Array[Task[_]], val stageId: Int, val attempt: Int, val priority: Int) { - val id: String = stageId + "." + attempt +private[spark] class TaskSet( + val tasks: Array[Task[_]], + val stageId: Int, + val attempt: Int, + val priority: Int, + val properties: Properties) { + val id: String = stageId + "." + attempt override def toString: String = "TaskSet " + id } diff --git a/core/src/main/scala/spark/scheduler/cluster/ClusterScheduler.scala b/core/src/main/scala/spark/scheduler/cluster/ClusterScheduler.scala index 26fdef101b..3a0c29b27f 100644 --- a/core/src/main/scala/spark/scheduler/cluster/ClusterScheduler.scala +++ b/core/src/main/scala/spark/scheduler/cluster/ClusterScheduler.scala @@ -1,6 +1,6 @@ package spark.scheduler.cluster -import java.io.{File, FileInputStream, FileOutputStream} +import java.lang.{Boolean => JBoolean} import scala.collection.mutable.ArrayBuffer import scala.collection.mutable.HashMap @@ -25,17 +25,45 @@ private[spark] class ClusterScheduler(val sc: SparkContext) val SPECULATION_INTERVAL = System.getProperty("spark.speculation.interval", "100").toLong // Threshold above which we warn user initial TaskSet may be starved val STARVATION_TIMEOUT = System.getProperty("spark.starvation.timeout", "15000").toLong + // How often to revive offers in case there are pending tasks - that is how often to try to 
get + // tasks scheduled in case there are nodes available : default 0 is to disable it - to preserve existing behavior + // Note that this is required due to delayed scheduling due to data locality waits, etc. + // TODO: rename property ? + val TASK_REVIVAL_INTERVAL = System.getProperty("spark.tasks.revive.interval", "0").toLong + + /* + This property controls how aggressive we should be to modulate waiting for node local task scheduling. + To elaborate, currently there is a time limit (3 sec by default) to ensure that spark attempts to wait for node locality of tasks before + scheduling on other nodes. We have modified this in yarn branch such that offers to task set happen in prioritized order : + node-local, rack-local and then others + But once all available node local (and no pref) tasks are scheduled, instead of waiting for 3 sec before + scheduling to other nodes (which degrades performance for time sensitive tasks and on larger clusters), we can + modulate that : to also allow rack local nodes or any node. The default is still set to NODE_LOCAL - so that previous behavior is + maintained. This is to allow tuning the tension between pulling rdd data off node and scheduling computation asap. + + TODO: rename property ? The value is one of + - NODE_LOCAL (default, no change w.r.t current behavior), + - RACK_LOCAL and + - ANY + + Note that this property makes more sense when used in conjunction with spark.tasks.revive.interval > 0 : else it is not very effective. + + Additional Note: For non trivial clusters, there is a 4x - 5x reduction in running time (in some of our experiments) based on whether + it is left at default NODE_LOCAL, RACK_LOCAL (if cluster is configured to be rack aware) or ANY. + If cluster is rack aware, then setting it to RACK_LOCAL gives best tradeoff and a 3x - 4x performance improvement while minimizing IO impact. + Also, it brings down the variance in running time drastically. 
+ */ + val TASK_SCHEDULING_AGGRESSION = TaskLocality.parse(System.getProperty("spark.tasks.schedule.aggression", "NODE_LOCAL")) val activeTaskSets = new HashMap[String, TaskSetManager] - var activeTaskSetsQueue = new ArrayBuffer[TaskSetManager] val taskIdToTaskSetId = new HashMap[Long, String] val taskIdToExecutorId = new HashMap[Long, String] val taskSetTaskIds = new HashMap[String, HashSet[Long]] - var hasReceivedTask = false - var hasLaunchedTask = false - val starvationTimer = new Timer(true) + @volatile private var hasReceivedTask = false + @volatile private var hasLaunchedTask = false + private val starvationTimer = new Timer(true) // Incrementing Mesos task IDs val nextTaskId = new AtomicLong(0) @@ -43,11 +71,16 @@ private[spark] class ClusterScheduler(val sc: SparkContext) // Which executor IDs we have executors on val activeExecutorIds = new HashSet[String] + // TODO: We might want to remove this and merge it with execId datastructures - but later. + // Which hosts in the cluster are alive (contains hostPort's) - used for process local and node local task locality. 
+ private val hostPortsAlive = new HashSet[String] + private val hostToAliveHostPorts = new HashMap[String, HashSet[String]] + // The set of executors we have on each host; this is used to compute hostsAlive, which // in turn is used to decide when we can attain data locality on a given host - val executorsByHost = new HashMap[String, HashSet[String]] + private val executorsByHostPort = new HashMap[String, HashSet[String]] - val executorIdToHost = new HashMap[String, String] + private val executorIdToHostPort = new HashMap[String, String] // JAR server, if any JARs were added by the user to the SparkContext var jarServer: HttpServer = null @@ -62,24 +95,50 @@ private[spark] class ClusterScheduler(val sc: SparkContext) val mapOutputTracker = SparkEnv.get.mapOutputTracker + var schedulableBuilder: SchedulableBuilder = null + var rootPool: Pool = null + override def setListener(listener: TaskSchedulerListener) { this.listener = listener } def initialize(context: SchedulerBackend) { backend = context + //default scheduler is FIFO + val schedulingMode = System.getProperty("spark.cluster.schedulingmode", "FIFO") + //temporarily set rootPool name to empty + rootPool = new Pool("", SchedulingMode.withName(schedulingMode), 0, 0) + schedulableBuilder = { + schedulingMode match { + case "FIFO" => + new FIFOSchedulableBuilder(rootPool) + case "FAIR" => + new FairSchedulableBuilder(rootPool) + } + } + schedulableBuilder.buildPools() + // resolve executorId to hostPort mapping. + def executorToHostPort(executorId: String, defaultHostPort: String): String = { + executorIdToHostPort.getOrElse(executorId, defaultHostPort) + } + + // Unfortunately, this means that SparkEnv is indirectly referencing ClusterScheduler + // Will that be a design violation ? 
+ SparkEnv.get.executorIdToHostPort = Some(executorToHostPort) } + def newTaskId(): Long = nextTaskId.getAndIncrement() override def start() { backend.start() - if (System.getProperty("spark.speculation", "false") == "true") { + if (JBoolean.getBoolean("spark.speculation")) { new Thread("ClusterScheduler speculation check") { setDaemon(true) override def run() { + logInfo("Starting speculative execution thread") while (true) { try { Thread.sleep(SPECULATION_INTERVAL) @@ -91,15 +150,36 @@ private[spark] class ClusterScheduler(val sc: SparkContext) } }.start() } + + + // Change to always run with some default if TASK_REVIVAL_INTERVAL <= 0 ? + if (TASK_REVIVAL_INTERVAL > 0) { + new Thread("ClusterScheduler task offer revival check") { + setDaemon(true) + + override def run() { + logInfo("Starting speculative task offer revival thread") + while (true) { + try { + Thread.sleep(TASK_REVIVAL_INTERVAL) + } catch { + case e: InterruptedException => {} + } + + if (hasPendingTasks()) backend.reviveOffers() + } + } + }.start() + } } override def submitTasks(taskSet: TaskSet) { val tasks = taskSet.tasks logInfo("Adding task set " + taskSet.id + " with " + tasks.length + " tasks") this.synchronized { - val manager = new TaskSetManager(this, taskSet) + val manager = new ClusterTaskSetManager(this, taskSet) activeTaskSets(taskSet.id) = manager - activeTaskSetsQueue += manager + schedulableBuilder.addTaskSetManager(manager, manager.taskSet.properties) taskSetTaskIds(taskSet.id) = new HashSet[Long]() if (hasReceivedTask == false) { @@ -122,7 +202,8 @@ private[spark] class ClusterScheduler(val sc: SparkContext) def taskSetFinished(manager: TaskSetManager) { this.synchronized { activeTaskSets -= manager.taskSet.id - activeTaskSetsQueue -= manager + manager.parent.removeSchedulable(manager) + logInfo("Remove TaskSet %s from pool %s".format(manager.taskSet.id, manager.parent.name)) taskIdToTaskSetId --= taskSetTaskIds(manager.taskSet.id) taskIdToExecutorId --= 
taskSetTaskIds(manager.taskSet.id) taskSetTaskIds.remove(manager.taskSet.id) @@ -139,22 +220,128 @@ private[spark] class ClusterScheduler(val sc: SparkContext) SparkEnv.set(sc.env) // Mark each slave as alive and remember its hostname for (o <- offers) { - executorIdToHost(o.executorId) = o.hostname - if (!executorsByHost.contains(o.hostname)) { - executorsByHost(o.hostname) = new HashSet() + // DEBUG Code + Utils.checkHostPort(o.hostPort) + + executorIdToHostPort(o.executorId) = o.hostPort + if (! executorsByHostPort.contains(o.hostPort)) { + executorsByHostPort(o.hostPort) = new HashSet[String]() } + + hostPortsAlive += o.hostPort + hostToAliveHostPorts.getOrElseUpdate(Utils.parseHostPort(o.hostPort)._1, new HashSet[String]).add(o.hostPort) + executorGained(o.executorId, o.hostPort) } // Build a list of tasks to assign to each slave val tasks = offers.map(o => new ArrayBuffer[TaskDescription](o.cores)) + // merge availableCpus into nodeToAvailableCpus block ? val availableCpus = offers.map(o => o.cores).toArray + val nodeToAvailableCpus = { + val map = new HashMap[String, Int]() + for (offer <- offers) { + val hostPort = offer.hostPort + val cores = offer.cores + // DEBUG code + Utils.checkHostPort(hostPort) + + val host = Utils.parseHostPort(hostPort)._1 + + map.put(host, map.getOrElse(host, 0) + cores) + } + + map + } var launchedTask = false - for (manager <- activeTaskSetsQueue.sortBy(m => (m.taskSet.priority, m.taskSet.stageId))) { + val sortedTaskSetQueue = rootPool.getSortedTaskSetQueue() + for (manager <- sortedTaskSetQueue) + { + logInfo("parentName:%s,name:%s,runningTasks:%s".format(manager.parent.name, manager.name, manager.runningTasks)) + } + for (manager <- sortedTaskSetQueue) { + + // Split offers based on node local, rack local and off-rack tasks. 
+ val processLocalOffers = new HashMap[String, ArrayBuffer[Int]]() + val nodeLocalOffers = new HashMap[String, ArrayBuffer[Int]]() + val rackLocalOffers = new HashMap[String, ArrayBuffer[Int]]() + val otherOffers = new HashMap[String, ArrayBuffer[Int]]() + + for (i <- 0 until offers.size) { + val hostPort = offers(i).hostPort + // DEBUG code + Utils.checkHostPort(hostPort) + + val numProcessLocalTasks = math.max(0, math.min(manager.numPendingTasksForHostPort(hostPort), availableCpus(i))) + if (numProcessLocalTasks > 0){ + val list = processLocalOffers.getOrElseUpdate(hostPort, new ArrayBuffer[Int]) + for (j <- 0 until numProcessLocalTasks) list += i + } + + val host = Utils.parseHostPort(hostPort)._1 + val numNodeLocalTasks = math.max(0, + // Remove process local tasks (which are also host local btw !) from this + math.min(manager.numPendingTasksForHost(hostPort) - numProcessLocalTasks, nodeToAvailableCpus(host))) + if (numNodeLocalTasks > 0){ + val list = nodeLocalOffers.getOrElseUpdate(host, new ArrayBuffer[Int]) + for (j <- 0 until numNodeLocalTasks) list += i + } + + val numRackLocalTasks = math.max(0, + // Remove node local tasks (which are also rack local btw !) from this + math.min(manager.numRackLocalPendingTasksForHost(hostPort) - numProcessLocalTasks - numNodeLocalTasks, nodeToAvailableCpus(host))) + if (numRackLocalTasks > 0){ + val list = rackLocalOffers.getOrElseUpdate(host, new ArrayBuffer[Int]) + for (j <- 0 until numRackLocalTasks) list += i + } + if (numNodeLocalTasks <= 0 && numRackLocalTasks <= 0){ + // add to others list - spread even this across cluster. + val list = otherOffers.getOrElseUpdate(host, new ArrayBuffer[Int]) + list += i + } + } + + val offersPriorityList = new ArrayBuffer[Int]( + processLocalOffers.size + nodeLocalOffers.size + rackLocalOffers.size + otherOffers.size) + + // First process local, then host local, then rack, then others + + // numNodeLocalOffers contains count of both process local and host offers. 
+ val numNodeLocalOffers = { + val processLocalPriorityList = ClusterScheduler.prioritizeContainers(processLocalOffers) + offersPriorityList ++= processLocalPriorityList + + val nodeLocalPriorityList = ClusterScheduler.prioritizeContainers(nodeLocalOffers) + offersPriorityList ++= nodeLocalPriorityList + + processLocalPriorityList.size + nodeLocalPriorityList.size + } + val numRackLocalOffers = { + val rackLocalPriorityList = ClusterScheduler.prioritizeContainers(rackLocalOffers) + offersPriorityList ++= rackLocalPriorityList + rackLocalPriorityList.size + } + offersPriorityList ++= ClusterScheduler.prioritizeContainers(otherOffers) + + var lastLoop = false + val lastLoopIndex = TASK_SCHEDULING_AGGRESSION match { + case TaskLocality.NODE_LOCAL => numNodeLocalOffers + case TaskLocality.RACK_LOCAL => numRackLocalOffers + numNodeLocalOffers + case TaskLocality.ANY => offersPriorityList.size + } + do { launchedTask = false - for (i <- 0 until offers.size) { + var loopCount = 0 + for (i <- offersPriorityList) { val execId = offers(i).executorId - val host = offers(i).hostname - manager.slaveOffer(execId, host, availableCpus(i)) match { + val hostPort = offers(i).hostPort + + // If last loop and within the lastLoopIndex, expand scope - else use null (which will use default/existing) + val overrideLocality = if (lastLoop && loopCount < lastLoopIndex) TASK_SCHEDULING_AGGRESSION else null + + // If last loop, override waiting for host locality - we scheduled all local tasks already and there might be more available ... 
+ loopCount += 1 + + manager.slaveOffer(execId, hostPort, availableCpus(i), overrideLocality) match { case Some(task) => tasks(i) += task val tid = task.taskId @@ -162,15 +349,31 @@ private[spark] class ClusterScheduler(val sc: SparkContext) taskSetTaskIds(manager.taskSet.id) += tid taskIdToExecutorId(tid) = execId activeExecutorIds += execId - executorsByHost(host) += execId + executorsByHostPort(hostPort) += execId availableCpus(i) -= 1 launchedTask = true case None => {} + } + } + // Loop once more - when lastLoop = true, then we try to schedule task on all nodes irrespective of + // data locality (we still go in order of priority : but that would not change anything since + // if data local tasks had been available, we would have scheduled them already) + if (lastLoop) { + // prevent more looping + launchedTask = false + } else if (!lastLoop && !launchedTask) { + // Do this only if TASK_SCHEDULING_AGGRESSION != NODE_LOCAL + if (TASK_SCHEDULING_AGGRESSION != TaskLocality.NODE_LOCAL) { + // fudge launchedTask to ensure we loop once more + launchedTask = true + // dont loop anymore + lastLoop = true } } } while (launchedTask) } + if (tasks.size > 0) { hasLaunchedTask = true } @@ -223,6 +426,7 @@ private[spark] class ClusterScheduler(val sc: SparkContext) backend.reviveOffers() } if (taskFailed) { + // Also revive offers if a task had failed for some reason other than host lost backend.reviveOffers() } @@ -256,29 +460,40 @@ private[spark] class ClusterScheduler(val sc: SparkContext) if (jarServer != null) { jarServer.stop() } + + // sleeping for an arbitrary 5 seconds : to ensure that messages are sent out. + // TODO: Do something better ! + Thread.sleep(5000L) } override def defaultParallelism() = backend.defaultParallelism() + // Check for speculatable tasks in all our active jobs. 
def checkSpeculatableTasks() { var shouldRevive = false synchronized { - for (ts <- activeTaskSetsQueue) { - shouldRevive |= ts.checkSpeculatableTasks() - } + shouldRevive = rootPool.checkSpeculatableTasks() } if (shouldRevive) { backend.reviveOffers() } } + // Check for pending tasks in all our active jobs. + def hasPendingTasks(): Boolean = { + synchronized { + rootPool.hasPendingTasks() + } + } + def executorLost(executorId: String, reason: ExecutorLossReason) { var failedExecutor: Option[String] = None + synchronized { if (activeExecutorIds.contains(executorId)) { - val host = executorIdToHost(executorId) - logError("Lost executor %s on %s: %s".format(executorId, host, reason)) + val hostPort = executorIdToHostPort(executorId) + logError("Lost executor %s on %s: %s".format(executorId, hostPort, reason)) removeExecutor(executorId) failedExecutor = Some(executorId) } else { @@ -296,19 +511,104 @@ private[spark] class ClusterScheduler(val sc: SparkContext) } } - /** Get a list of hosts that currently have executors */ - def hostsAlive: scala.collection.Set[String] = executorsByHost.keySet - /** Remove an executor from all our data structures and mark it as lost */ private def removeExecutor(executorId: String) { activeExecutorIds -= executorId - val host = executorIdToHost(executorId) - val execs = executorsByHost.getOrElse(host, new HashSet) + val hostPort = executorIdToHostPort(executorId) + if (hostPortsAlive.contains(hostPort)) { + // DEBUG Code + Utils.checkHostPort(hostPort) + + hostPortsAlive -= hostPort + hostToAliveHostPorts.getOrElseUpdate(Utils.parseHostPort(hostPort)._1, new HashSet[String]).remove(hostPort) + } + + val execs = executorsByHostPort.getOrElse(hostPort, new HashSet) execs -= executorId if (execs.isEmpty) { - executorsByHost -= host + executorsByHostPort -= hostPort } - executorIdToHost -= executorId - activeTaskSetsQueue.foreach(_.executorLost(executorId, host)) + executorIdToHostPort -= executorId + rootPool.executorLost(executorId, 
hostPort) + } + + def executorGained(execId: String, hostPort: String) { + listener.executorGained(execId, hostPort) + } + + def getExecutorsAliveOnHost(host: String): Option[Set[String]] = { + Utils.checkHost(host) + + val retval = hostToAliveHostPorts.get(host) + if (retval.isDefined) { + return Some(retval.get.toSet) + } + + None + } + + def isExecutorAliveOnHostPort(hostPort: String): Boolean = { + // Even if hostPort is a host, it does not matter - it is just a specific check. + // But we do have to ensure that only hostPort get into hostPortsAlive ! + // So no check against Utils.checkHostPort + hostPortsAlive.contains(hostPort) + } + + // By default, rack is unknown + def getRackForHost(value: String): Option[String] = None + + // By default, (cached) hosts for rack is unknown + def getCachedHostsForRack(rack: String): Option[Set[String]] = None +} + + +object ClusterScheduler { + + // Used to 'spray' available containers across the available set to ensure too many containers on same host + // are not used up. Used in yarn mode and in task scheduling (when there are multiple containers available + // to execute a task) + // For example: yarn can returns more containers than we would have requested under ANY, this method + // prioritizes how to use the allocated containers. + // flatten the map such that the array buffer entries are spread out across the returned value. + // given <host, list[container]> == <h1, [c1 .. c5]>, <h2, [c1 .. c3]>, <h3, [c1, c2]>, <h4, c1>, <h5, c1>, i + // the return value would be something like : h1c1, h2c1, h3c1, h4c1, h5c1, h1c2, h2c2, h3c2, h1c3, h2c3, h1c4, h1c5 + // We then 'use' the containers in this order (consuming only the top K from this list where + // K = number to be user). This is to ensure that if we have multiple eligible allocations, + // they dont end up allocating all containers on a small number of hosts - increasing probability of + // multiple container failure when a host goes down. 
+ // Note, there is bias for keys with higher number of entries in value to be picked first (by design) + // Also note that invocation of this method is expected to have containers of same 'type' + // (host-local, rack-local, off-rack) and not across types : so that reordering is simply better from + // the available list - everything else being same. + // That is, we we first consume data local, then rack local and finally off rack nodes. So the + // prioritization from this method applies to within each category + def prioritizeContainers[K, T] (map: HashMap[K, ArrayBuffer[T]]): List[T] = { + val _keyList = new ArrayBuffer[K](map.size) + _keyList ++= map.keys + + // order keyList based on population of value in map + val keyList = _keyList.sortWith( + (left, right) => map.get(left).getOrElse(Set()).size > map.get(right).getOrElse(Set()).size + ) + + val retval = new ArrayBuffer[T](keyList.size * 2) + var index = 0 + var found = true + + while (found){ + found = false + for (key <- keyList) { + val containerList: ArrayBuffer[T] = map.get(key).getOrElse(null) + assert(containerList != null) + // Get the index'th entry for this host - if present + if (index < containerList.size){ + retval += containerList.apply(index) + found = true + } + } + index += 1 + } + + retval.toList } } diff --git a/core/src/main/scala/spark/scheduler/cluster/ClusterTaskSetManager.scala b/core/src/main/scala/spark/scheduler/cluster/ClusterTaskSetManager.scala new file mode 100644 index 0000000000..d72b0bfc9f --- /dev/null +++ b/core/src/main/scala/spark/scheduler/cluster/ClusterTaskSetManager.scala @@ -0,0 +1,747 @@ +package spark.scheduler.cluster + +import java.util.{HashMap => JHashMap, NoSuchElementException, Arrays} + +import scala.collection.mutable.ArrayBuffer +import scala.collection.mutable.HashMap +import scala.collection.mutable.HashSet +import scala.math.max +import scala.math.min + +import spark._ +import spark.scheduler._ +import spark.TaskState.TaskState +import 
java.nio.ByteBuffer + +private[spark] object TaskLocality extends Enumeration("PROCESS_LOCAL", "NODE_LOCAL", "RACK_LOCAL", "ANY") with Logging { + + // process local is expected to be used ONLY within tasksetmanager for now. + val PROCESS_LOCAL, NODE_LOCAL, RACK_LOCAL, ANY = Value + + type TaskLocality = Value + + def isAllowed(constraint: TaskLocality, condition: TaskLocality): Boolean = { + + // Must not be the constraint. + assert (constraint != TaskLocality.PROCESS_LOCAL) + + constraint match { + case TaskLocality.NODE_LOCAL => condition == TaskLocality.NODE_LOCAL + case TaskLocality.RACK_LOCAL => condition == TaskLocality.NODE_LOCAL || condition == TaskLocality.RACK_LOCAL + // For anything else, allow + case _ => true + } + } + + def parse(str: String): TaskLocality = { + // better way to do this ? + try { + val retval = TaskLocality.withName(str) + // Must not specify PROCESS_LOCAL ! + assert (retval != TaskLocality.PROCESS_LOCAL) + + retval + } catch { + case nEx: NoSuchElementException => { + logWarning("Invalid task locality specified '" + str + "', defaulting to NODE_LOCAL"); + // default to preserve earlier behavior + NODE_LOCAL + } + } + } +} + +/** + * Schedules the tasks within a single TaskSet in the ClusterScheduler. 
+ */ +private[spark] class ClusterTaskSetManager( + sched: ClusterScheduler, + val taskSet: TaskSet) + extends TaskSetManager + with Logging { + + // Maximum time to wait to run a task in a preferred location (in ms) + val LOCALITY_WAIT = System.getProperty("spark.locality.wait", "3000").toLong + + // CPUs to request per task + val CPUS_PER_TASK = System.getProperty("spark.task.cpus", "1").toDouble + + // Maximum times a task is allowed to fail before failing the job + val MAX_TASK_FAILURES = 4 + + // Quantile of tasks at which to start speculation + val SPECULATION_QUANTILE = System.getProperty("spark.speculation.quantile", "0.75").toDouble + val SPECULATION_MULTIPLIER = System.getProperty("spark.speculation.multiplier", "1.5").toDouble + + // Serializer for closures and tasks. + val ser = SparkEnv.get.closureSerializer.newInstance() + + val tasks = taskSet.tasks + val numTasks = tasks.length + val copiesRunning = new Array[Int](numTasks) + val finished = new Array[Boolean](numTasks) + val numFailures = new Array[Int](numTasks) + val taskAttempts = Array.fill[List[TaskInfo]](numTasks)(Nil) + var tasksFinished = 0 + + var weight = 1 + var minShare = 0 + var runningTasks = 0 + var priority = taskSet.priority + var stageId = taskSet.stageId + var name = "TaskSet_"+taskSet.stageId.toString + var parent:Schedulable = null + + // Last time when we launched a preferred task (for delay scheduling) + var lastPreferredLaunchTime = System.currentTimeMillis + + // List of pending tasks for each node (process local to container). These collections are actually + // treated as stacks, in which new tasks are added to the end of the + // ArrayBuffer and removed from the end. This makes it faster to detect + // tasks that repeatedly fail because whenever a task failed, it is put + // back at the head of the stack. 
They are also only cleaned up lazily; + // when a task is launched, it remains in all the pending lists except + // the one that it was launched from, but gets removed from them later. + private val pendingTasksForHostPort = new HashMap[String, ArrayBuffer[Int]] + + // List of pending tasks for each node. + // Essentially, similar to pendingTasksForHostPort, except at host level + private val pendingTasksForHost = new HashMap[String, ArrayBuffer[Int]] + + // List of pending tasks for each node based on rack locality. + // Essentially, similar to pendingTasksForHost, except at rack level + private val pendingRackLocalTasksForHost = new HashMap[String, ArrayBuffer[Int]] + + // List containing pending tasks with no locality preferences + val pendingTasksWithNoPrefs = new ArrayBuffer[Int] + + // List containing all pending tasks (also used as a stack, as above) + val allPendingTasks = new ArrayBuffer[Int] + + // Tasks that can be speculated. Since these will be a small fraction of total + // tasks, we'll just hold them in a HashSet. + val speculatableTasks = new HashSet[Int] + + // Task index, start and finish time for each task attempt (indexed by task ID) + val taskInfos = new HashMap[Long, TaskInfo] + + // Did the job fail? + var failed = false + var causeOfFailure = "" + + // How frequently to reprint duplicate exceptions in full, in milliseconds + val EXCEPTION_PRINT_INTERVAL = + System.getProperty("spark.logging.exceptionPrintInterval", "10000").toLong + // Map of recent exceptions (identified by string representation and + // top stack frame) to duplicate count (how many times the same + // exception has appeared) and time the full exception was + // printed. This should ideally be an LRU map that can drop old + // exceptions automatically. 
val recentExceptions = new HashMap[String, (Int, Long)]

// Stamp every task with the map output tracker generation that is current right now.
val generation = sched.mapOutputTracker.getGeneration
logDebug("Generation for " + taskSet.id + ": " + generation)
tasks.foreach(task => task.generation = generation)

// Register all tasks as pending. Indices are walked from highest to lowest so that the
// low-index tasks end up at the tail of each pending list, which is the end tasks are
// dequeued from (see findTaskFromList), and therefore get launched first.
(numTasks - 1 to 0 by -1).foreach(i => addPendingTask(i))

// Collects the alive locations that satisfy `taskLocality` for the given preferred locations.
// The levels nest:
//   - PROCESS_LOCAL keeps the preferred entries the scheduler reports as alive host:ports.
//   - NODE_LOCAL returns whatever the scheduler reports as alive on each preferred host
//     (so it includes the PROCESS_LOCAL results).
//   - RACK_LOCAL does the same after widening the preferred hosts with the cached hosts of
//     their racks (so it includes the PROCESS_LOCAL and NODE_LOCAL results).
// Any other level trips the assertion below.
private def findPreferredLocations(_taskPreferredLocations: Seq[String], scheduler: ClusterScheduler,
    taskLocality: TaskLocality.TaskLocality): HashSet[String] = {

  val aliveLocations = new HashSet[String]

  if (taskLocality == TaskLocality.PROCESS_LOCAL) {
    // Straightforward comparison, special-cased: no host/rack expansion is needed.
    scheduler.synchronized {
      for (location <- _taskPreferredLocations if scheduler.isExecutorAliveOnHostPort(location)) {
        aliveLocations += location
      }
    }
  } else {
    val candidateLocations =
      if (taskLocality == TaskLocality.NODE_LOCAL) {
        _taskPreferredLocations
      } else {
        assert (TaskLocality.RACK_LOCAL == taskLocality)
        // Expand the set to include all 'seen' rack local hosts.
        // Original author's note: this works since container allocation/management happens
        // within the master, so any rack locality information is updated in the master.
        // Best-effort, and maybe sort of a kludge for now ... rework it later ?
        val rackExpanded = new HashSet[String]
        for (preferred <- _taskPreferredLocations) {
          for (rack <- scheduler.getRackForHost(preferred);
               rackHosts <- scheduler.getCachedHostsForRack(rack)) {
            rackExpanded ++= rackHosts
          }
          // Irrespective of what the scheduler says, the preferred host itself is always added.
          rackExpanded += preferred
        }
        rackExpanded
      }

    scheduler.synchronized {
      candidateLocations.foreach { prefLocation =>
        val host = Utils.parseHostPort(prefLocation)._1
        scheduler.getExecutorsAliveOnHost(host).foreach(aliveLocations ++= _)
      }
    }
  }

  aliveLocations
}

// Add a task to all the pending-task lists that it should be on.
// The host-local locations could be derived from the rack-local ones by joining against
// tasks(index).preferredLocations (with the appropriate hostPort <-> host conversion); that is
// deliberately not done, for simplicity. Revisit if this becomes a performance issue.
private def addPendingTask(index: Int): Unit = {
  // Alive locations for this task at the requested locality level.
  def aliveAt(level: TaskLocality.TaskLocality) =
    findPreferredLocations(tasks(index).preferredLocations, sched, level)

  // Append this task to the pending list stored under `key`, creating the list on first use.
  def enqueue(pending: HashMap[String, ArrayBuffer[Int]], key: String): Unit = {
    pending.getOrElseUpdate(key, new ArrayBuffer[Int]) += index
  }

  val processLocalLocations = aliveAt(TaskLocality.PROCESS_LOCAL)
  val hostLocalLocations = aliveAt(TaskLocality.NODE_LOCAL)
  val rackLocalLocations = aliveAt(TaskLocality.RACK_LOCAL)

  if (rackLocalLocations.isEmpty) {
    // The rack-local set is a superset of the other two, so they must be empty as well.
    assert (processLocalLocations.size == 0)
    assert (hostLocalLocations.size == 0)
    pendingTasksWithNoPrefs += index
  } else {
    // process local locality: keyed by the full host:port
    processLocalLocations.foreach { hostPort =>
      Utils.checkHostPort(hostPort) // DEBUG check
      enqueue(pendingTasksForHostPort, hostPort)
    }

    // host locality (includes process local): keyed by host only
    hostLocalLocations.foreach { hostPort =>
      Utils.checkHostPort(hostPort) // DEBUG check
      enqueue(pendingTasksForHost, Utils.parseHostPort(hostPort)._1)
    }

    // rack locality (includes process local and host local): keyed by host only
    rackLocalLocations.foreach { rackLocalHostPort =>
      Utils.checkHostPort(rackLocalHostPort) // DEBUG check
      enqueue(pendingRackLocalTasksForHost, Utils.parseHostPort(rackLocalHostPort)._1)
    }
  }

  allPendingTasks += index
}

// Return the pending tasks list for a given host port (process local), or an empty list if
// there is no map entry for that host port
private def getPendingTasksForHostPort(hostPort: String): ArrayBuffer[Int] = {
  Utils.checkHostPort(hostPort) // DEBUG check
  pendingTasksForHostPort.get(hostPort) match {
    case Some(pending) => pending
    case None => new ArrayBuffer[Int]
  }
}

// Return the pending tasks list for the host part of the given host port, or an empty list if
// there is no map entry for that host
private def getPendingTasksForHost(hostPort: String): ArrayBuffer[Int] = {
  pendingTasksForHost.get(Utils.parseHostPort(hostPort)._1) match {
    case Some(pending) => pending
    case None => new ArrayBuffer[Int]
  }
}

// Return the pending tasks (rack level) list for the host part of the given host port, or an
// empty list if there is no map entry for that host
private def getRackLocalPendingTasksForHost(hostPort: String): ArrayBuffer[Int] = {
  pendingRackLocalTasksForHost.get(Utils.parseHostPort(hostPort)._1) match {
    case Some(pending) => pending
    case None => new ArrayBuffer[Int]
  }
}

// Number of pending tasks for a given host Port (which would be process local).
// Only tasks with no running copy that have not finished are counted.
def numPendingTasksForHostPort(hostPort: String): Int = {
  val pending = getPendingTasksForHostPort(hostPort)
  pending.count(index => !finished(index) && copiesRunning(index) == 0)
}

// Number of pending tasks for a given host (which would be data local).
// Only tasks with no running copy that have not finished are counted.
def numPendingTasksForHost(hostPort: String): Int = {
  val pending = getPendingTasksForHost(hostPort)
  pending.count(index => !finished(index) && copiesRunning(index) == 0)
}

// Number of pending rack local tasks for a given host.
// Only tasks with no running copy that have not finished are counted.
def numRackLocalPendingTasksForHost(hostPort: String): Int = {
  val pending = getRackLocalPendingTasksForHost(hostPort)
  pending.count(index => !finished(index) && copiesRunning(index) == 0)
}


// Dequeue a pending task from the given list and return its index.
// Return None if the list is empty (or holds no launchable task).
// Entries are popped from the end of the list; entries for tasks that already have a running
// copy or have finished are discarded along the way, which is how the lists get cleaned up
// lazily.
private def findTaskFromList(list: ArrayBuffer[Int]): Option[Int] = {
  var picked: Option[Int] = None
  while (picked.isEmpty && list.nonEmpty) {
    val index = list.remove(list.size - 1)
    if (copiesRunning(index) == 0 && !finished(index)) {
      picked = Some(index)
    }
  }
  picked
}

// Return a speculative task for a given host if any are available. The task should not have an
// attempt running on this host, in case the host is slow. In addition, if locality is set, the
// task must have a preference for this host/rack/no preferred locations at all.
private def findSpeculativeTask(hostPort: String, locality: TaskLocality.TaskLocality): Option[Int] = {

  assert (TaskLocality.isAllowed(locality, TaskLocality.NODE_LOCAL))
  // Drop tasks that have finished since they were marked as speculatable.
  speculatableTasks.retain(index => !finished(index))

  // True when no earlier attempt of this task was placed on the offered host:port.
  def notAttemptedHere(index: Int): Boolean =
    !taskAttempts(index).map(_.hostPort).contains(hostPort)

  // Alive locations for the task at the given locality level.
  def aliveAt(index: Int, level: TaskLocality.TaskLocality) =
    findPreferredLocations(tasks(index).preferredLocations, sched, level)

  // Host-level match: the task either has no alive node-local location at all, or the
  // offered host:port is one of them.
  def hostLocalCandidate: Option[Int] = speculatableTasks.find { index =>
    val locations = aliveAt(index, TaskLocality.NODE_LOCAL)
    (locations.isEmpty || locations.contains(hostPort)) && notAttemptedHere(index)
  }

  // Rack-level match, considered only when the requested locality permits it.
  def rackLocalCandidate: Option[Int] =
    if (TaskLocality.isAllowed(locality, TaskLocality.RACK_LOCAL)) {
      speculatableTasks.find { index =>
        aliveAt(index, TaskLocality.RACK_LOCAL).contains(hostPort) && notAttemptedHere(index)
      }
    } else {
      None
    }

  // Any task not yet attempted here, considered only when the requested locality permits it.
  def anyCandidate: Option[Int] =
    if (TaskLocality.isAllowed(locality, TaskLocality.ANY)) {
      speculatableTasks.find(notAttemptedHere)
    } else {
      None
    }

  if (speculatableTasks.isEmpty) {
    None
  } else {
    // orElse takes its argument by name, so a wider level is searched only when every
    // narrower level came back empty.
    val chosen = hostLocalCandidate.orElse(rackLocalCandidate).orElse(anyCandidate)
    // The selected task is handed out at most once: remove it from the speculatable set.
    chosen.foreach(index => speculatableTasks -= index)
    chosen
  }
}

// Dequeue a pending task for a given node and return its index.
// Rack-local and non-local tasks are considered only when `locality` allows them.
private def findTask(hostPort: String, locality: TaskLocality.TaskLocality): Option[Int] = {
  // Rack-local pending tasks, consulted only when the requested locality permits it.
  def rackLocalTask: Option[Int] =
    if (TaskLocality.isAllowed(locality, TaskLocality.RACK_LOCAL)) {
      findTaskFromList(getRackLocalPendingTasksForHost(hostPort))
    } else {
      None
    }

  // Any pending task at all, consulted only when the requested locality permits it.
  def nonLocalTask: Option[Int] =
    if (TaskLocality.isAllowed(locality, TaskLocality.ANY)) {
      findTaskFromList(allPendingTasks)
    } else {
      None
    }

  // The sources are tried from most to least local. orElse takes its argument by name, so a
  // later source is only consulted (and its list only trimmed) when all earlier ones were empty.
  findTaskFromList(getPendingTasksForHostPort(hostPort))    // process local
    .orElse(findTaskFromList(getPendingTasksForHost(hostPort)))  // host local
    .orElse(rackLocalTask)
    // No-pref tasks are looked at AFTER rack local tasks - this has the side effect that we
    // will get to failed tasks later rather than sooner.
    // TODO: That code path needs to be revisited (adding to no prefs list when host:port goes down).
    .orElse(findTaskFromList(pendingTasksWithNoPrefs))
    .orElse(nonLocalTask)
    // Finally, if all else has failed, find a speculative task
    .orElse(findSpeculativeTask(hostPort, locality))
}

// A launch is process local when the offered host:port itself is one of the task's
// preferred locations.
private def isProcessLocalLocation(task: Task[_], hostPort: String): Boolean = {
  Utils.checkHostPort(hostPort)
  task.preferredLocations.contains(hostPort)
}

// A launch is host local when the task has no preferred locations at all, or when one of its
// preferred locations has the same host part as the offered host:port.
private def isHostLocalLocation(task: Task[_], hostPort: String): Boolean = {
  val preferred = task.preferredLocations

  // If no preference, consider it as host local
  preferred.isEmpty || {
    val offeredHost = Utils.parseHostPort(hostPort)._1
    preferred.exists(location => Utils.parseHostPort(location)._1 == offeredHost)
  }
}

// Does a host count as a rack local preferred location for a task? (assumes host is NOT preferred location).
+ // This is true if either the task has preferred locations and this host is one, or it has + // no preferred locations (in which we still count the launch as preferred). + private def isRackLocalLocation(task: Task[_], hostPort: String): Boolean = { + + val locs = task.preferredLocations + + val preferredRacks = new HashSet[String]() + for (preferredHost <- locs) { + val rack = sched.getRackForHost(preferredHost) + if (None != rack) preferredRacks += rack.get + } + + if (preferredRacks.isEmpty) return false + + val hostRack = sched.getRackForHost(hostPort) + + return None != hostRack && preferredRacks.contains(hostRack.get) + } + + // Respond to an offer of a single slave from the scheduler by finding a task + def slaveOffer(execId: String, hostPort: String, availableCpus: Double, overrideLocality: TaskLocality.TaskLocality = null): Option[TaskDescription] = { + + if (tasksFinished < numTasks && availableCpus >= CPUS_PER_TASK) { + // If explicitly specified, use that + val locality = if (overrideLocality != null) overrideLocality else { + // expand only if we have waited for more than LOCALITY_WAIT for a host local task ... 
+ val time = System.currentTimeMillis + if (time - lastPreferredLaunchTime < LOCALITY_WAIT) TaskLocality.NODE_LOCAL else TaskLocality.ANY + } + + findTask(hostPort, locality) match { + case Some(index) => { + // Found a task; do some bookkeeping and return a Mesos task for it + val task = tasks(index) + val taskId = sched.newTaskId() + // Figure out whether this should count as a preferred launch + val taskLocality = + if (isProcessLocalLocation(task, hostPort)) TaskLocality.PROCESS_LOCAL else + if (isHostLocalLocation(task, hostPort)) TaskLocality.NODE_LOCAL else + if (isRackLocalLocation(task, hostPort)) TaskLocality.RACK_LOCAL else + TaskLocality.ANY + val prefStr = taskLocality.toString + logInfo("Starting task %s:%d as TID %s on slave %s: %s (%s)".format( + taskSet.id, index, taskId, execId, hostPort, prefStr)) + // Do various bookkeeping + copiesRunning(index) += 1 + val time = System.currentTimeMillis + val info = new TaskInfo(taskId, index, time, execId, hostPort, taskLocality) + taskInfos(taskId) = info + taskAttempts(index) = info :: taskAttempts(index) + if (TaskLocality.NODE_LOCAL == taskLocality) { + lastPreferredLaunchTime = time + } + // Serialize and return the task + val startTime = System.currentTimeMillis + val serializedTask = Task.serializeWithDependencies( + task, sched.sc.addedFiles, sched.sc.addedJars, ser) + val timeTaken = System.currentTimeMillis - startTime + increaseRunningTasks(1) + logInfo("Serialized task %s:%d as %d bytes in %d ms".format( + taskSet.id, index, serializedTask.limit, timeTaken)) + val taskName = "task %s:%d".format(taskSet.id, index) + return Some(new TaskDescription(taskId, execId, taskName, serializedTask)) + } + case _ => + } + } + return None + } + + def statusUpdate(tid: Long, state: TaskState, serializedData: ByteBuffer) { + state match { + case TaskState.FINISHED => + taskFinished(tid, state, serializedData) + case TaskState.LOST => + taskLost(tid, state, serializedData) + case TaskState.FAILED => + 
taskLost(tid, state, serializedData) + case TaskState.KILLED => + taskLost(tid, state, serializedData) + case _ => + } + } + + def taskFinished(tid: Long, state: TaskState, serializedData: ByteBuffer) { + val info = taskInfos(tid) + if (info.failed) { + // We might get two task-lost messages for the same task in coarse-grained Mesos mode, + // or even from Mesos itself when acks get delayed. + return + } + val index = info.index + info.markSuccessful() + decreaseRunningTasks(1) + if (!finished(index)) { + tasksFinished += 1 + logInfo("Finished TID %s in %d ms (progress: %d/%d)".format( + tid, info.duration, tasksFinished, numTasks)) + // Deserialize task result and pass it to the scheduler + try { + val result = ser.deserialize[TaskResult[_]](serializedData) + result.metrics.resultSize = serializedData.limit() + sched.listener.taskEnded(tasks(index), Success, result.value, result.accumUpdates, info, result.metrics) + } catch { + case cnf: ClassNotFoundException => + val loader = Thread.currentThread().getContextClassLoader + throw new SparkException("ClassNotFound with classloader: " + loader, cnf) + case ex => throw ex + } + // Mark finished and stop if we've finished all the tasks + finished(index) = true + if (tasksFinished == numTasks) { + sched.taskSetFinished(this) + } + } else { + logInfo("Ignoring task-finished event for TID " + tid + + " because task " + index + " is already finished") + } + } + + def taskLost(tid: Long, state: TaskState, serializedData: ByteBuffer) { + val info = taskInfos(tid) + if (info.failed) { + // We might get two task-lost messages for the same task in coarse-grained Mesos mode, + // or even from Mesos itself when acks get delayed. + return + } + val index = info.index + info.markFailed() + decreaseRunningTasks(1) + if (!finished(index)) { + logInfo("Lost TID %s (task %s:%d)".format(tid, taskSet.id, index)) + copiesRunning(index) -= 1 + // Check if the problem is a map output fetch failure. 
In that case, this + // task will never succeed on any node, so tell the scheduler about it. + if (serializedData != null && serializedData.limit() > 0) { + val reason = ser.deserialize[TaskEndReason](serializedData, getClass.getClassLoader) + reason match { + case fetchFailed: FetchFailed => + logInfo("Loss was due to fetch failure from " + fetchFailed.bmAddress) + sched.listener.taskEnded(tasks(index), fetchFailed, null, null, info, null) + finished(index) = true + tasksFinished += 1 + sched.taskSetFinished(this) + decreaseRunningTasks(runningTasks) + return + + case taskResultTooBig: TaskResultTooBigFailure => + logInfo("Loss was due to task %s result exceeding Akka frame size; " + + "aborting job".format(tid)) + abort("Task %s result exceeded Akka frame size".format(tid)) + return + + case ef: ExceptionFailure => + val key = ef.description + val now = System.currentTimeMillis + val (printFull, dupCount) = { + if (recentExceptions.contains(key)) { + val (dupCount, printTime) = recentExceptions(key) + if (now - printTime > EXCEPTION_PRINT_INTERVAL) { + recentExceptions(key) = (0, now) + (true, 0) + } else { + recentExceptions(key) = (dupCount + 1, printTime) + (false, dupCount + 1) + } + } else { + recentExceptions(key) = (0, now) + (true, 0) + } + } + if (printFull) { + val locs = ef.stackTrace.map(loc => "\tat %s".format(loc.toString)) + logInfo("Loss was due to %s\n%s\n%s".format( + ef.className, ef.description, locs.mkString("\n"))) + } else { + logInfo("Loss was due to %s [duplicate %d]".format(ef.description, dupCount)) + } + + case _ => {} + } + } + // On non-fetch failures, re-enqueue the task as pending for a max number of retries + addPendingTask(index) + // Count failed attempts only on FAILED and LOST state (not on KILLED) + if (state == TaskState.FAILED || state == TaskState.LOST) { + numFailures(index) += 1 + if (numFailures(index) > MAX_TASK_FAILURES) { + logError("Task %s:%d failed more than %d times; aborting job".format( + taskSet.id, index, 
MAX_TASK_FAILURES)) + abort("Task %s:%d failed more than %d times".format(taskSet.id, index, MAX_TASK_FAILURES)) + } + } + } else { + logInfo("Ignoring task-lost event for TID " + tid + + " because task " + index + " is already finished") + } + } + + def error(message: String) { + // Save the error message + abort("Error: " + message) + } + + def abort(message: String) { + failed = true + causeOfFailure = message + // TODO: Kill running tasks if we were not terminated due to a Mesos error + sched.listener.taskSetFailed(taskSet, message) + decreaseRunningTasks(runningTasks) + sched.taskSetFinished(this) + } + + override def increaseRunningTasks(taskNum: Int) { + runningTasks += taskNum + if (parent != null) { + parent.increaseRunningTasks(taskNum) + } + } + + override def decreaseRunningTasks(taskNum: Int) { + runningTasks -= taskNum + if (parent != null) { + parent.decreaseRunningTasks(taskNum) + } + } + + //TODO: for now we just find Pool not TaskSetManager, we can extend this function in future if needed + override def getSchedulableByName(name: String): Schedulable = { + return null + } + + override def addSchedulable(schedulable:Schedulable) { + //nothing + } + + override def removeSchedulable(schedulable:Schedulable) { + //nothing + } + + override def getSortedTaskSetQueue(): ArrayBuffer[TaskSetManager] = { + var sortedTaskSetQueue = new ArrayBuffer[TaskSetManager] + sortedTaskSetQueue += this + return sortedTaskSetQueue + } + + override def executorLost(execId: String, hostPort: String) { + logInfo("Re-queueing tasks for " + execId + " from TaskSet " + taskSet.id) + + // If some task has preferred locations only on hostname, and there are no more executors there, + // put it in the no-prefs list to avoid the wait from delay scheduling + + // host local tasks - should we push this to rack local or no pref list ? For now, preserving behavior and moving to + // no prefs list. Note, this was done due to impliations related to 'waiting' for data local tasks, etc. 
+ // Note: NOT checking process local list - since host local list is super set of that. We need to ad to no prefs only if + // there is no host local node for the task (not if there is no process local node for the task) + for (index <- getPendingTasksForHost(Utils.parseHostPort(hostPort)._1)) { + // val newLocs = findPreferredLocations(tasks(index).preferredLocations, sched, TaskLocality.RACK_LOCAL) + val newLocs = findPreferredLocations(tasks(index).preferredLocations, sched, TaskLocality.NODE_LOCAL) + if (newLocs.isEmpty) { + pendingTasksWithNoPrefs += index + } + } + + // Re-enqueue any tasks that ran on the failed executor if this is a shuffle map stage + if (tasks(0).isInstanceOf[ShuffleMapTask]) { + for ((tid, info) <- taskInfos if info.executorId == execId) { + val index = taskInfos(tid).index + if (finished(index)) { + finished(index) = false + copiesRunning(index) -= 1 + tasksFinished -= 1 + addPendingTask(index) + // Tell the DAGScheduler that this task was resubmitted so that it doesn't think our + // stage finishes when a total of tasks.size tasks finish. + sched.listener.taskEnded(tasks(index), Resubmitted, null, null, info, null) + } + } + } + // Also re-enqueue any tasks that were running on the node + for ((tid, info) <- taskInfos if info.running && info.executorId == execId) { + taskLost(tid, TaskState.KILLED, null) + } + } + + /** + * Check for tasks to be speculated and return true if there are any. This is called periodically + * by the ClusterScheduler. + * + * TODO: To make this scale to large jobs, we need to maintain a list of running tasks, so that + * we don't scan the whole task set. It might also help to make this sorted by launch time. + */ + override def checkSpeculatableTasks(): Boolean = { + // Can't speculate if we only have one task, or if all tasks have finished. 
+ if (numTasks == 1 || tasksFinished == numTasks) { + return false + } + var foundTasks = false + val minFinishedForSpeculation = (SPECULATION_QUANTILE * numTasks).floor.toInt + logDebug("Checking for speculative tasks: minFinished = " + minFinishedForSpeculation) + if (tasksFinished >= minFinishedForSpeculation) { + val time = System.currentTimeMillis() + val durations = taskInfos.values.filter(_.successful).map(_.duration).toArray + Arrays.sort(durations) + val medianDuration = durations(min((0.5 * numTasks).round.toInt, durations.size - 1)) + val threshold = max(SPECULATION_MULTIPLIER * medianDuration, 100) + // TODO: Threshold should also look at standard deviation of task durations and have a lower + // bound based on that. + logDebug("Task length threshold for speculation: " + threshold) + for ((tid, info) <- taskInfos) { + val index = info.index + if (!finished(index) && copiesRunning(index) == 1 && info.timeRunning(time) > threshold && + !speculatableTasks.contains(index)) { + logInfo( + "Marking task %s:%d (on %s) as speculatable because it ran more than %.0f ms".format( + taskSet.id, index, info.hostPort, threshold)) + speculatableTasks += index + foundTasks = true + } + } + } + return foundTasks + } + + override def hasPendingTasks(): Boolean = { + numTasks > 0 && tasksFinished < numTasks + } +} diff --git a/core/src/main/scala/spark/scheduler/cluster/Pool.scala b/core/src/main/scala/spark/scheduler/cluster/Pool.scala new file mode 100644 index 0000000000..941ba7a3f1 --- /dev/null +++ b/core/src/main/scala/spark/scheduler/cluster/Pool.scala @@ -0,0 +1,104 @@ +package spark.scheduler.cluster + +import scala.collection.mutable.ArrayBuffer +import scala.collection.mutable.HashMap + +import spark.Logging +import spark.scheduler.cluster.SchedulingMode.SchedulingMode + +/** + * An Schedulable entity that represent collection of Pools or TaskSetManagers + */ + +private[spark] class Pool( + val poolName: String, + val schedulingMode: SchedulingMode, + 
initMinShare: Int, + initWeight: Int) + extends Schedulable + with Logging { + + var schedulableQueue = new ArrayBuffer[Schedulable] + var schedulableNameToSchedulable = new HashMap[String, Schedulable] + + var weight = initWeight + var minShare = initMinShare + var runningTasks = 0 + + var priority = 0 + var stageId = 0 + var name = poolName + var parent:Schedulable = null + + var taskSetSchedulingAlgorithm: SchedulingAlgorithm = { + schedulingMode match { + case SchedulingMode.FAIR => + new FairSchedulingAlgorithm() + case SchedulingMode.FIFO => + new FIFOSchedulingAlgorithm() + } + } + + override def addSchedulable(schedulable: Schedulable) { + schedulableQueue += schedulable + schedulableNameToSchedulable(schedulable.name) = schedulable + schedulable.parent= this + } + + override def removeSchedulable(schedulable: Schedulable) { + schedulableQueue -= schedulable + schedulableNameToSchedulable -= schedulable.name + } + + override def getSchedulableByName(schedulableName: String): Schedulable = { + if (schedulableNameToSchedulable.contains(schedulableName)) { + return schedulableNameToSchedulable(schedulableName) + } + for (schedulable <- schedulableQueue) { + var sched = schedulable.getSchedulableByName(schedulableName) + if (sched != null) { + return sched + } + } + return null + } + + override def executorLost(executorId: String, host: String) { + schedulableQueue.foreach(_.executorLost(executorId, host)) + } + + override def checkSpeculatableTasks(): Boolean = { + var shouldRevive = false + for (schedulable <- schedulableQueue) { + shouldRevive |= schedulable.checkSpeculatableTasks() + } + return shouldRevive + } + + override def getSortedTaskSetQueue(): ArrayBuffer[TaskSetManager] = { + var sortedTaskSetQueue = new ArrayBuffer[TaskSetManager] + val sortedSchedulableQueue = schedulableQueue.sortWith(taskSetSchedulingAlgorithm.comparator) + for (schedulable <- sortedSchedulableQueue) { + sortedTaskSetQueue ++= schedulable.getSortedTaskSetQueue() + } + return 
sortedTaskSetQueue + } + + override def increaseRunningTasks(taskNum: Int) { + runningTasks += taskNum + if (parent != null) { + parent.increaseRunningTasks(taskNum) + } + } + + override def decreaseRunningTasks(taskNum: Int) { + runningTasks -= taskNum + if (parent != null) { + parent.decreaseRunningTasks(taskNum) + } + } + + override def hasPendingTasks(): Boolean = { + schedulableQueue.exists(_.hasPendingTasks()) + } +} diff --git a/core/src/main/scala/spark/scheduler/cluster/Schedulable.scala b/core/src/main/scala/spark/scheduler/cluster/Schedulable.scala new file mode 100644 index 0000000000..2dd9c0564f --- /dev/null +++ b/core/src/main/scala/spark/scheduler/cluster/Schedulable.scala @@ -0,0 +1,27 @@ +package spark.scheduler.cluster + +import scala.collection.mutable.ArrayBuffer + +/** + * An interface for schedulable entities. + * there are two type of Schedulable entities(Pools and TaskSetManagers) + */ +private[spark] trait Schedulable { + var parent: Schedulable + def weight: Int + def minShare: Int + def runningTasks: Int + def priority: Int + def stageId: Int + def name: String + + def increaseRunningTasks(taskNum: Int): Unit + def decreaseRunningTasks(taskNum: Int): Unit + def addSchedulable(schedulable: Schedulable): Unit + def removeSchedulable(schedulable: Schedulable): Unit + def getSchedulableByName(name: String): Schedulable + def executorLost(executorId: String, host: String): Unit + def checkSpeculatableTasks(): Boolean + def getSortedTaskSetQueue(): ArrayBuffer[TaskSetManager] + def hasPendingTasks(): Boolean +} diff --git a/core/src/main/scala/spark/scheduler/cluster/SchedulableBuilder.scala b/core/src/main/scala/spark/scheduler/cluster/SchedulableBuilder.scala new file mode 100644 index 0000000000..18cc15c2a5 --- /dev/null +++ b/core/src/main/scala/spark/scheduler/cluster/SchedulableBuilder.scala @@ -0,0 +1,115 @@ +package spark.scheduler.cluster + +import java.io.{File, FileInputStream, FileOutputStream} + +import 
scala.collection.mutable.ArrayBuffer +import scala.collection.mutable.ArrayBuffer +import scala.collection.mutable.HashMap +import scala.collection.mutable.HashSet +import scala.util.control.Breaks._ +import scala.xml._ + +import spark.Logging +import spark.scheduler.cluster.SchedulingMode.SchedulingMode + +import java.util.Properties + +/** + * An interface to build Schedulable tree + * buildPools: build the tree nodes(pools) + * addTaskSetManager: build the leaf nodes(TaskSetManagers) + */ +private[spark] trait SchedulableBuilder { + def buildPools() + def addTaskSetManager(manager: Schedulable, properties: Properties) +} + +private[spark] class FIFOSchedulableBuilder(val rootPool: Pool) extends SchedulableBuilder with Logging { + + override def buildPools() { + //nothing + } + + override def addTaskSetManager(manager: Schedulable, properties: Properties) { + rootPool.addSchedulable(manager) + } +} + +private[spark] class FairSchedulableBuilder(val rootPool: Pool) extends SchedulableBuilder with Logging { + + val schedulerAllocFile = System.getProperty("spark.fairscheduler.allocation.file","unspecified") + val FAIR_SCHEDULER_PROPERTIES = "spark.scheduler.cluster.fair.pool" + val DEFAULT_POOL_NAME = "default" + val MINIMUM_SHARES_PROPERTY = "minShare" + val SCHEDULING_MODE_PROPERTY = "schedulingMode" + val WEIGHT_PROPERTY = "weight" + val POOL_NAME_PROPERTY = "@name" + val POOLS_PROPERTY = "pool" + val DEFAULT_SCHEDULING_MODE = SchedulingMode.FIFO + val DEFAULT_MINIMUM_SHARE = 2 + val DEFAULT_WEIGHT = 1 + + override def buildPools() { + val file = new File(schedulerAllocFile) + if (file.exists()) { + val xml = XML.loadFile(file) + for (poolNode <- (xml \\ POOLS_PROPERTY)) { + + val poolName = (poolNode \ POOL_NAME_PROPERTY).text + var schedulingMode = DEFAULT_SCHEDULING_MODE + var minShare = DEFAULT_MINIMUM_SHARE + var weight = DEFAULT_WEIGHT + + val xmlSchedulingMode = (poolNode \ SCHEDULING_MODE_PROPERTY).text + if (xmlSchedulingMode != "") { + try { + 
schedulingMode = SchedulingMode.withName(xmlSchedulingMode) + } catch { + case e: Exception => logInfo("Error xml schedulingMode, using default schedulingMode") + } + } + + val xmlMinShare = (poolNode \ MINIMUM_SHARES_PROPERTY).text + if (xmlMinShare != "") { + minShare = xmlMinShare.toInt + } + + val xmlWeight = (poolNode \ WEIGHT_PROPERTY).text + if (xmlWeight != "") { + weight = xmlWeight.toInt + } + + val pool = new Pool(poolName, schedulingMode, minShare, weight) + rootPool.addSchedulable(pool) + logInfo("Create new pool with name:%s,schedulingMode:%s,minShare:%d,weight:%d".format( + poolName, schedulingMode, minShare, weight)) + } + } + + //finally create "default" pool + if (rootPool.getSchedulableByName(DEFAULT_POOL_NAME) == null) { + val pool = new Pool(DEFAULT_POOL_NAME, DEFAULT_SCHEDULING_MODE, DEFAULT_MINIMUM_SHARE, DEFAULT_WEIGHT) + rootPool.addSchedulable(pool) + logInfo("Create default pool with name:%s,schedulingMode:%s,minShare:%d,weight:%d".format( + DEFAULT_POOL_NAME, DEFAULT_SCHEDULING_MODE, DEFAULT_MINIMUM_SHARE, DEFAULT_WEIGHT)) + } + } + + override def addTaskSetManager(manager: Schedulable, properties: Properties) { + var poolName = DEFAULT_POOL_NAME + var parentPool = rootPool.getSchedulableByName(poolName) + if (properties != null) { + poolName = properties.getProperty(FAIR_SCHEDULER_PROPERTIES, DEFAULT_POOL_NAME) + parentPool = rootPool.getSchedulableByName(poolName) + if (parentPool == null) { + //we will create a new pool that user has configured in app instead of being defined in xml file + parentPool = new Pool(poolName,DEFAULT_SCHEDULING_MODE, DEFAULT_MINIMUM_SHARE, DEFAULT_WEIGHT) + rootPool.addSchedulable(parentPool) + logInfo("Create pool with name:%s,schedulingMode:%s,minShare:%d,weight:%d".format( + poolName, DEFAULT_SCHEDULING_MODE, DEFAULT_MINIMUM_SHARE, DEFAULT_WEIGHT)) + } + } + parentPool.addSchedulable(manager) + logInfo("Added task set " + manager.name + " tasks to pool "+poolName) + } +} diff --git 
a/core/src/main/scala/spark/scheduler/cluster/SchedulerBackend.scala b/core/src/main/scala/spark/scheduler/cluster/SchedulerBackend.scala index 9ac875de3a..8844057a5c 100644 --- a/core/src/main/scala/spark/scheduler/cluster/SchedulerBackend.scala +++ b/core/src/main/scala/spark/scheduler/cluster/SchedulerBackend.scala @@ -1,6 +1,6 @@ package spark.scheduler.cluster -import spark.Utils +import spark.{SparkContext, Utils} /** * A backend interface for cluster scheduling systems that allows plugging in different ones under @@ -14,14 +14,7 @@ private[spark] trait SchedulerBackend { def defaultParallelism(): Int // Memory used by each executor (in megabytes) - protected val executorMemory = { - // TODO: Might need to add some extra memory for the non-heap parts of the JVM - Option(System.getProperty("spark.executor.memory")) - .orElse(Option(System.getenv("SPARK_MEM"))) - .map(Utils.memoryStringToMb) - .getOrElse(512) - } - + protected val executorMemory: Int = SparkContext.executorMemoryRequested // TODO: Probably want to add a killTask too } diff --git a/core/src/main/scala/spark/scheduler/cluster/SchedulingAlgorithm.scala b/core/src/main/scala/spark/scheduler/cluster/SchedulingAlgorithm.scala new file mode 100644 index 0000000000..f33310a34a --- /dev/null +++ b/core/src/main/scala/spark/scheduler/cluster/SchedulingAlgorithm.scala @@ -0,0 +1,64 @@ +package spark.scheduler.cluster + +/** + * An interface for sort algorithm + * FIFO: FIFO algorithm between TaskSetManagers + * FS: FS algorithm between Pools, and FIFO or FS within Pools + */ +private[spark] trait SchedulingAlgorithm { + def comparator(s1: Schedulable, s2: Schedulable): Boolean +} + +private[spark] class FIFOSchedulingAlgorithm extends SchedulingAlgorithm { + override def comparator(s1: Schedulable, s2: Schedulable): Boolean = { + val priority1 = s1.priority + val priority2 = s2.priority + var res = math.signum(priority1 - priority2) + if (res == 0) { + val stageId1 = s1.stageId + val stageId2 = s2.stageId 
+ res = math.signum(stageId1 - stageId2) + } + if (res < 0) { + return true + } else { + return false + } + } +} + +private[spark] class FairSchedulingAlgorithm extends SchedulingAlgorithm { + override def comparator(s1: Schedulable, s2: Schedulable): Boolean = { + val minShare1 = s1.minShare + val minShare2 = s2.minShare + val runningTasks1 = s1.runningTasks + val runningTasks2 = s2.runningTasks + val s1Needy = runningTasks1 < minShare1 + val s2Needy = runningTasks2 < minShare2 + val minShareRatio1 = runningTasks1.toDouble / math.max(minShare1, 1.0).toDouble + val minShareRatio2 = runningTasks2.toDouble / math.max(minShare2, 1.0).toDouble + val taskToWeightRatio1 = runningTasks1.toDouble / s1.weight.toDouble + val taskToWeightRatio2 = runningTasks2.toDouble / s2.weight.toDouble + var res:Boolean = true + var compare:Int = 0 + + if (s1Needy && !s2Needy) { + return true + } else if (!s1Needy && s2Needy) { + return false + } else if (s1Needy && s2Needy) { + compare = minShareRatio1.compareTo(minShareRatio2) + } else { + compare = taskToWeightRatio1.compareTo(taskToWeightRatio2) + } + + if (compare < 0) { + return true + } else if (compare > 0) { + return false + } else { + return s1.name < s2.name + } + } +} + diff --git a/core/src/main/scala/spark/scheduler/cluster/SchedulingMode.scala b/core/src/main/scala/spark/scheduler/cluster/SchedulingMode.scala new file mode 100644 index 0000000000..6e0c6793e0 --- /dev/null +++ b/core/src/main/scala/spark/scheduler/cluster/SchedulingMode.scala @@ -0,0 +1,7 @@ +package spark.scheduler.cluster + +object SchedulingMode extends Enumeration("FAIR","FIFO"){ + + type SchedulingMode = Value + val FAIR,FIFO = Value +} diff --git a/core/src/main/scala/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala b/core/src/main/scala/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala index bb289c9cf3..170ede0f44 100644 --- a/core/src/main/scala/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala +++ 
b/core/src/main/scala/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala @@ -31,7 +31,8 @@ private[spark] class SparkDeploySchedulerBackend( val command = Command("spark.executor.StandaloneExecutorBackend", args, sc.executorEnvs) val sparkHome = sc.getSparkHome().getOrElse( throw new IllegalArgumentException("must supply spark home for spark standalone")) - val appDesc = new ApplicationDescription(appName, maxCores, executorMemory, command, sparkHome) + val appDesc = new ApplicationDescription(appName, maxCores, executorMemory, command, sparkHome, + sc.ui.appUIAddress) client = new Client(sc.env.actorSystem, master, appDesc, this) client.start() @@ -57,9 +58,9 @@ private[spark] class SparkDeploySchedulerBackend( } } - override def executorAdded(executorId: String, workerId: String, host: String, cores: Int, memory: Int) { - logInfo("Granted executor ID %s on host %s with %d cores, %s RAM".format( - executorId, host, cores, Utils.memoryMegabytesToString(memory))) + override def executorAdded(executorId: String, workerId: String, hostPort: String, cores: Int, memory: Int) { + logInfo("Granted executor ID %s on hostPort %s with %d cores, %s RAM".format( + executorId, hostPort, cores, Utils.memoryMegabytesToString(memory))) } override def executorRemoved(executorId: String, message: String, exitStatus: Option[Int]) { diff --git a/core/src/main/scala/spark/scheduler/cluster/StandaloneClusterMessage.scala b/core/src/main/scala/spark/scheduler/cluster/StandaloneClusterMessage.scala index d766067824..3335294844 100644 --- a/core/src/main/scala/spark/scheduler/cluster/StandaloneClusterMessage.scala +++ b/core/src/main/scala/spark/scheduler/cluster/StandaloneClusterMessage.scala @@ -3,6 +3,7 @@ package spark.scheduler.cluster import spark.TaskState.TaskState import java.nio.ByteBuffer import spark.util.SerializableBuffer +import spark.Utils private[spark] sealed trait StandaloneClusterMessage extends Serializable @@ -19,8 +20,10 @@ case class 
RegisterExecutorFailed(message: String) extends StandaloneClusterMess // Executors to driver private[spark] -case class RegisterExecutor(executorId: String, host: String, cores: Int) - extends StandaloneClusterMessage +case class RegisterExecutor(executorId: String, hostPort: String, cores: Int) + extends StandaloneClusterMessage { + Utils.checkHostPort(hostPort, "Expected host port") +} private[spark] case class StatusUpdate(executorId: String, taskId: Long, state: TaskState, data: SerializableBuffer) diff --git a/core/src/main/scala/spark/scheduler/cluster/StandaloneSchedulerBackend.scala b/core/src/main/scala/spark/scheduler/cluster/StandaloneSchedulerBackend.scala index a06d853b46..16131215c8 100644 --- a/core/src/main/scala/spark/scheduler/cluster/StandaloneSchedulerBackend.scala +++ b/core/src/main/scala/spark/scheduler/cluster/StandaloneSchedulerBackend.scala @@ -6,8 +6,9 @@ import akka.actor._ import scala.concurrent.duration._ import akka.pattern.ask -import spark.{SparkException, Logging, TaskState} +import spark.{Utils, SparkException, Logging, TaskState} import scala.concurrent.Await + import java.util.concurrent.atomic.AtomicInteger import akka.remote.{RemoteClientShutdown, RemoteClientDisconnected, RemoteClientLifeCycleEvent} @@ -24,12 +25,12 @@ class StandaloneSchedulerBackend(scheduler: ClusterScheduler, actorSystem: Actor var totalCoreCount = new AtomicInteger(0) class DriverActor(sparkProperties: Seq[(String, String)]) extends Actor { - val executorActor = new HashMap[String, ActorRef] - val executorAddress = new HashMap[String, Address] - val executorHost = new HashMap[String, String] - val freeCores = new HashMap[String, Int] - val actorToExecutorId = new HashMap[ActorRef, String] - val addressToExecutorId = new HashMap[Address, String] + private val executorActor = new HashMap[String, ActorRef] + private val executorAddress = new HashMap[String, Address] + private val executorHostPort = new HashMap[String, String] + private val freeCores = new 
HashMap[String, Int] + private val actorToExecutorId = new HashMap[ActorRef, String] + private val addressToExecutorId = new HashMap[Address, String] override def preStart() { // Listen for remote client disconnection events, since they don't go through Akka's watch() @@ -37,7 +38,8 @@ class StandaloneSchedulerBackend(scheduler: ClusterScheduler, actorSystem: Actor } def receive = { - case RegisterExecutor(executorId, host, cores) => + case RegisterExecutor(executorId, hostPort, cores) => + Utils.checkHostPort(hostPort, "Host port expected " + hostPort) if (executorActor.contains(executorId)) { sender ! RegisterExecutorFailed("Duplicate executor ID: " + executorId) } else { @@ -45,7 +47,7 @@ class StandaloneSchedulerBackend(scheduler: ClusterScheduler, actorSystem: Actor sender ! RegisteredExecutor(sparkProperties) context.watch(sender) executorActor(executorId) = sender - executorHost(executorId) = host + executorHostPort(executorId) = hostPort freeCores(executorId) = cores executorAddress(executorId) = sender.path.address actorToExecutorId(sender) = executorId @@ -85,13 +87,13 @@ class StandaloneSchedulerBackend(scheduler: ClusterScheduler, actorSystem: Actor // Make fake resource offers on all executors def makeOffers() { launchTasks(scheduler.resourceOffers( - executorHost.toArray.map {case (id, host) => new WorkerOffer(id, host, freeCores(id))})) + executorHostPort.toArray.map {case (id, hostPort) => new WorkerOffer(id, hostPort, freeCores(id))})) } // Make fake resource offers on just one executor def makeOffers(executorId: String) { launchTasks(scheduler.resourceOffers( - Seq(new WorkerOffer(executorId, executorHost(executorId), freeCores(executorId))))) + Seq(new WorkerOffer(executorId, executorHostPort(executorId), freeCores(executorId))))) } // Launch tasks returned by a set of resource offers @@ -110,9 +112,9 @@ class StandaloneSchedulerBackend(scheduler: ClusterScheduler, actorSystem: Actor actorToExecutorId -= executorActor(executorId) 
addressToExecutorId -= executorAddress(executorId) executorActor -= executorId - executorHost -= executorId + executorHostPort -= executorId freeCores -= executorId - executorHost -= executorId + executorHostPort -= executorId totalCoreCount.addAndGet(-numCores) scheduler.executorLost(executorId, SlaveLost(reason)) } @@ -128,7 +130,7 @@ class StandaloneSchedulerBackend(scheduler: ClusterScheduler, actorSystem: Actor while (iterator.hasNext) { val entry = iterator.next val (key, value) = (entry.getKey.toString, entry.getValue.toString) - if (key.startsWith("spark.")) { + if (key.startsWith("spark.") && !key.equals("spark.hostPort")) { properties += ((key, value)) } } @@ -136,10 +138,11 @@ class StandaloneSchedulerBackend(scheduler: ClusterScheduler, actorSystem: Actor Props(new DriverActor(properties)), name = StandaloneSchedulerBackend.ACTOR_NAME) } + private val timeout = Duration.create(System.getProperty("spark.akka.askTimeout", "10").toLong, "seconds") + override def stop() { try { if (driverActor != null) { - val timeout = 5.seconds val future = driverActor.ask(StopDriver)(timeout) Await.result(future, timeout) } @@ -159,7 +162,6 @@ class StandaloneSchedulerBackend(scheduler: ClusterScheduler, actorSystem: Actor // Called by subclasses when notified of a lost worker def removeExecutor(executorId: String, reason: String) { try { - val timeout = 5.seconds val future = driverActor.ask(RemoveExecutor(executorId, reason))(timeout) Await.result(future, timeout) } catch { diff --git a/core/src/main/scala/spark/scheduler/cluster/TaskInfo.scala b/core/src/main/scala/spark/scheduler/cluster/TaskInfo.scala index dfe3c5a85b..718f26bfbd 100644 --- a/core/src/main/scala/spark/scheduler/cluster/TaskInfo.scala +++ b/core/src/main/scala/spark/scheduler/cluster/TaskInfo.scala @@ -1,5 +1,7 @@ package spark.scheduler.cluster +import spark.Utils + /** * Information about a running task attempt inside a TaskSet. 
*/ @@ -9,8 +11,11 @@ class TaskInfo( val index: Int, val launchTime: Long, val executorId: String, - val host: String, - val preferred: Boolean) { + val hostPort: String, + val taskLocality: TaskLocality.TaskLocality) { + + Utils.checkHostPort(hostPort, "Expected hostport") + var finishTime: Long = 0 var failed = false diff --git a/core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala b/core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala index c9f2c48804..b4dd75d90f 100644 --- a/core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala +++ b/core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala @@ -1,430 +1,17 @@ package spark.scheduler.cluster -import java.util.Arrays -import java.util.{HashMap => JHashMap} - import scala.collection.mutable.ArrayBuffer -import scala.collection.mutable.HashMap -import scala.collection.mutable.HashSet -import scala.math.max -import scala.math.min - -import spark._ import spark.scheduler._ import spark.TaskState.TaskState import java.nio.ByteBuffer -/** - * Schedules the tasks within a single TaskSet in the ClusterScheduler. - */ -private[spark] class TaskSetManager(sched: ClusterScheduler, val taskSet: TaskSet) extends Logging { - - // Maximum time to wait to run a task in a preferred location (in ms) - val LOCALITY_WAIT = System.getProperty("spark.locality.wait", "3000").toLong - - // CPUs to request per task - val CPUS_PER_TASK = System.getProperty("spark.task.cpus", "1").toDouble - - // Maximum times a task is allowed to fail before failing the job - val MAX_TASK_FAILURES = 4 - - // Quantile of tasks at which to start speculation - val SPECULATION_QUANTILE = System.getProperty("spark.speculation.quantile", "0.75").toDouble - val SPECULATION_MULTIPLIER = System.getProperty("spark.speculation.multiplier", "1.5").toDouble - - // Serializer for closures and tasks. 
- val ser = SparkEnv.get.closureSerializer.newInstance() - - val priority = taskSet.priority - val tasks = taskSet.tasks - val numTasks = tasks.length - val copiesRunning = new Array[Int](numTasks) - val finished = new Array[Boolean](numTasks) - val numFailures = new Array[Int](numTasks) - val taskAttempts = Array.fill[List[TaskInfo]](numTasks)(Nil) - var tasksFinished = 0 - - // Last time when we launched a preferred task (for delay scheduling) - var lastPreferredLaunchTime = System.currentTimeMillis - - // List of pending tasks for each node. These collections are actually - // treated as stacks, in which new tasks are added to the end of the - // ArrayBuffer and removed from the end. This makes it faster to detect - // tasks that repeatedly fail because whenever a task failed, it is put - // back at the head of the stack. They are also only cleaned up lazily; - // when a task is launched, it remains in all the pending lists except - // the one that it was launched from, but gets removed from them later. - val pendingTasksForHost = new HashMap[String, ArrayBuffer[Int]] - - // List containing pending tasks with no locality preferences - val pendingTasksWithNoPrefs = new ArrayBuffer[Int] - - // List containing all pending tasks (also used as a stack, as above) - val allPendingTasks = new ArrayBuffer[Int] - - // Tasks that can be speculated. Since these will be a small fraction of total - // tasks, we'll just hold them in a HashSet. - val speculatableTasks = new HashSet[Int] - - // Task index, start and finish time for each task attempt (indexed by task ID) - val taskInfos = new HashMap[Long, TaskInfo] - - // Did the job fail? 
- var failed = false - var causeOfFailure = "" - - // How frequently to reprint duplicate exceptions in full, in milliseconds - val EXCEPTION_PRINT_INTERVAL = - System.getProperty("spark.logging.exceptionPrintInterval", "10000").toLong - // Map of recent exceptions (identified by string representation and - // top stack frame) to duplicate count (how many times the same - // exception has appeared) and time the full exception was - // printed. This should ideally be an LRU map that can drop old - // exceptions automatically. - val recentExceptions = HashMap[String, (Int, Long)]() - - // Figure out the current map output tracker generation and set it on all tasks - val generation = sched.mapOutputTracker.getGeneration - logDebug("Generation for " + taskSet.id + ": " + generation) - for (t <- tasks) { - t.generation = generation - } - - // Add all our tasks to the pending lists. We do this in reverse order - // of task index so that tasks with low indices get launched first. - for (i <- (0 until numTasks).reverse) { - addPendingTask(i) - } - - // Add a task to all the pending-task lists that it should be on. - private def addPendingTask(index: Int) { - val locations = tasks(index).preferredLocations.toSet & sched.hostsAlive - if (locations.size == 0) { - pendingTasksWithNoPrefs += index - } else { - for (host <- locations) { - val list = pendingTasksForHost.getOrElseUpdate(host, ArrayBuffer()) - list += index - } - } - allPendingTasks += index - } - - // Return the pending tasks list for a given host, or an empty list if - // there is no map entry for that host - private def getPendingTasksForHost(host: String): ArrayBuffer[Int] = { - pendingTasksForHost.getOrElse(host, ArrayBuffer()) - } - - // Dequeue a pending task from the given list and return its index. - // Return None if the list is empty. - // This method also cleans up any tasks in the list that have already - // been launched, since we want that to happen lazily. 
- private def findTaskFromList(list: ArrayBuffer[Int]): Option[Int] = { - while (!list.isEmpty) { - val index = list.last - list.trimEnd(1) - if (copiesRunning(index) == 0 && !finished(index)) { - return Some(index) - } - } - return None - } - - // Return a speculative task for a given host if any are available. The task should not have an - // attempt running on this host, in case the host is slow. In addition, if localOnly is set, the - // task must have a preference for this host (or no preferred locations at all). - private def findSpeculativeTask(host: String, localOnly: Boolean): Option[Int] = { - val hostsAlive = sched.hostsAlive - speculatableTasks.retain(index => !finished(index)) // Remove finished tasks from set - val localTask = speculatableTasks.find { - index => - val locations = tasks(index).preferredLocations.toSet & hostsAlive - val attemptLocs = taskAttempts(index).map(_.host) - (locations.size == 0 || locations.contains(host)) && !attemptLocs.contains(host) - } - if (localTask != None) { - speculatableTasks -= localTask.get - return localTask - } - if (!localOnly && speculatableTasks.size > 0) { - val nonLocalTask = speculatableTasks.find(i => !taskAttempts(i).map(_.host).contains(host)) - if (nonLocalTask != None) { - speculatableTasks -= nonLocalTask.get - return nonLocalTask - } - } - return None - } - - // Dequeue a pending task for a given node and return its index. - // If localOnly is set to false, allow non-local tasks as well. 
- private def findTask(host: String, localOnly: Boolean): Option[Int] = { - val localTask = findTaskFromList(getPendingTasksForHost(host)) - if (localTask != None) { - return localTask - } - val noPrefTask = findTaskFromList(pendingTasksWithNoPrefs) - if (noPrefTask != None) { - return noPrefTask - } - if (!localOnly) { - val nonLocalTask = findTaskFromList(allPendingTasks) - if (nonLocalTask != None) { - return nonLocalTask - } - } - // Finally, if all else has failed, find a speculative task - return findSpeculativeTask(host, localOnly) - } - - // Does a host count as a preferred location for a task? This is true if - // either the task has preferred locations and this host is one, or it has - // no preferred locations (in which we still count the launch as preferred). - private def isPreferredLocation(task: Task[_], host: String): Boolean = { - val locs = task.preferredLocations - return (locs.contains(host) || locs.isEmpty) - } - - // Respond to an offer of a single slave from the scheduler by finding a task - def slaveOffer(execId: String, host: String, availableCpus: Double): Option[TaskDescription] = { - if (tasksFinished < numTasks && availableCpus >= CPUS_PER_TASK) { - val time = System.currentTimeMillis - val localOnly = (time - lastPreferredLaunchTime < LOCALITY_WAIT) - - findTask(host, localOnly) match { - case Some(index) => { - // Found a task; do some bookkeeping and return a Mesos task for it - val task = tasks(index) - val taskId = sched.newTaskId() - // Figure out whether this should count as a preferred launch - val preferred = isPreferredLocation(task, host) - val prefStr = if (preferred) { - "preferred" - } else { - "non-preferred, not one of " + task.preferredLocations.mkString(", ") - } - logInfo("Starting task %s:%d as TID %s on executor %s: %s (%s)".format( - taskSet.id, index, taskId, execId, host, prefStr)) - // Do various bookkeeping - copiesRunning(index) += 1 - val info = new TaskInfo(taskId, index, time, execId, host, preferred) - 
taskInfos(taskId) = info - taskAttempts(index) = info :: taskAttempts(index) - if (preferred) { - lastPreferredLaunchTime = time - } - // Serialize and return the task - val startTime = System.currentTimeMillis - val serializedTask = Task.serializeWithDependencies( - task, sched.sc.addedFiles, sched.sc.addedJars, ser) - val timeTaken = System.currentTimeMillis - startTime - logInfo("Serialized task %s:%d as %d bytes in %d ms".format( - taskSet.id, index, serializedTask.limit, timeTaken)) - val taskName = "task %s:%d".format(taskSet.id, index) - return Some(new TaskDescription(taskId, execId, taskName, serializedTask)) - } - case _ => - } - } - return None - } - - def statusUpdate(tid: Long, state: TaskState, serializedData: ByteBuffer) { - state match { - case TaskState.FINISHED => - taskFinished(tid, state, serializedData) - case TaskState.LOST => - taskLost(tid, state, serializedData) - case TaskState.FAILED => - taskLost(tid, state, serializedData) - case TaskState.KILLED => - taskLost(tid, state, serializedData) - case _ => - } - } - - def taskFinished(tid: Long, state: TaskState, serializedData: ByteBuffer) { - val info = taskInfos(tid) - if (info.failed) { - // We might get two task-lost messages for the same task in coarse-grained Mesos mode, - // or even from Mesos itself when acks get delayed. 
- return - } - val index = info.index - info.markSuccessful() - if (!finished(index)) { - tasksFinished += 1 - logInfo("Finished TID %s in %d ms (progress: %d/%d)".format( - tid, info.duration, tasksFinished, numTasks)) - // Deserialize task result and pass it to the scheduler - val result = ser.deserialize[TaskResult[_]](serializedData, getClass.getClassLoader) - result.metrics.resultSize = serializedData.limit() - sched.listener.taskEnded(tasks(index), Success, result.value, result.accumUpdates, info, result.metrics) - // Mark finished and stop if we've finished all the tasks - finished(index) = true - if (tasksFinished == numTasks) { - sched.taskSetFinished(this) - } - } else { - logInfo("Ignoring task-finished event for TID " + tid + - " because task " + index + " is already finished") - } - } - - def taskLost(tid: Long, state: TaskState, serializedData: ByteBuffer) { - val info = taskInfos(tid) - if (info.failed) { - // We might get two task-lost messages for the same task in coarse-grained Mesos mode, - // or even from Mesos itself when acks get delayed. - return - } - val index = info.index - info.markFailed() - if (!finished(index)) { - logInfo("Lost TID %s (task %s:%d)".format(tid, taskSet.id, index)) - copiesRunning(index) -= 1 - // Check if the problem is a map output fetch failure. In that case, this - // task will never succeed on any node, so tell the scheduler about it. 
- if (serializedData != null && serializedData.limit() > 0) { - val reason = ser.deserialize[TaskEndReason](serializedData, getClass.getClassLoader) - reason match { - case fetchFailed: FetchFailed => - logInfo("Loss was due to fetch failure from " + fetchFailed.bmAddress) - sched.listener.taskEnded(tasks(index), fetchFailed, null, null, info, null) - finished(index) = true - tasksFinished += 1 - sched.taskSetFinished(this) - return - - case ef: ExceptionFailure => - val key = ef.exception.toString - val now = System.currentTimeMillis - val (printFull, dupCount) = { - if (recentExceptions.contains(key)) { - val (dupCount, printTime) = recentExceptions(key) - if (now - printTime > EXCEPTION_PRINT_INTERVAL) { - recentExceptions(key) = (0, now) - (true, 0) - } else { - recentExceptions(key) = (dupCount + 1, printTime) - (false, dupCount + 1) - } - } else { - recentExceptions(key) = (0, now) - (true, 0) - } - } - if (printFull) { - val locs = ef.exception.getStackTrace.map(loc => "\tat %s".format(loc.toString)) - logInfo("Loss was due to %s\n%s".format(ef.exception.toString, locs.mkString("\n"))) - } else { - logInfo("Loss was due to %s [duplicate %d]".format(ef.exception.toString, dupCount)) - } - - case _ => {} - } - } - // On non-fetch failures, re-enqueue the task as pending for a max number of retries - addPendingTask(index) - // Count failed attempts only on FAILED and LOST state (not on KILLED) - if (state == TaskState.FAILED || state == TaskState.LOST) { - numFailures(index) += 1 - if (numFailures(index) > MAX_TASK_FAILURES) { - logError("Task %s:%d failed more than %d times; aborting job".format( - taskSet.id, index, MAX_TASK_FAILURES)) - abort("Task %s:%d failed more than %d times".format(taskSet.id, index, MAX_TASK_FAILURES)) - } - } - } else { - logInfo("Ignoring task-lost event for TID " + tid + - " because task " + index + " is already finished") - } - } - - def error(message: String) { - // Save the error message - abort("Error: " + message) - } - - def 
abort(message: String) { - failed = true - causeOfFailure = message - // TODO: Kill running tasks if we were not terminated due to a Mesos error - sched.listener.taskSetFailed(taskSet, message) - sched.taskSetFinished(this) - } - - def executorLost(execId: String, hostname: String) { - logInfo("Re-queueing tasks for " + execId + " from TaskSet " + taskSet.id) - val newHostsAlive = sched.hostsAlive - // If some task has preferred locations only on hostname, and there are no more executors there, - // put it in the no-prefs list to avoid the wait from delay scheduling - if (!newHostsAlive.contains(hostname)) { - for (index <- getPendingTasksForHost(hostname)) { - val newLocs = tasks(index).preferredLocations.toSet & newHostsAlive - if (newLocs.isEmpty) { - pendingTasksWithNoPrefs += index - } - } - } - // Re-enqueue any tasks that ran on the failed executor if this is a shuffle map stage - if (tasks(0).isInstanceOf[ShuffleMapTask]) { - for ((tid, info) <- taskInfos if info.executorId == execId) { - val index = taskInfos(tid).index - if (finished(index)) { - finished(index) = false - copiesRunning(index) -= 1 - tasksFinished -= 1 - addPendingTask(index) - // Tell the DAGScheduler that this task was resubmitted so that it doesn't think our - // stage finishes when a total of tasks.size tasks finish. - sched.listener.taskEnded(tasks(index), Resubmitted, null, null, info, null) - } - } - } - // Also re-enqueue any tasks that were running on the node - for ((tid, info) <- taskInfos if info.running && info.executorId == execId) { - taskLost(tid, TaskState.KILLED, null) - } - } - - /** - * Check for tasks to be speculated and return true if there are any. This is called periodically - * by the ClusterScheduler. - * - * TODO: To make this scale to large jobs, we need to maintain a list of running tasks, so that - * we don't scan the whole task set. It might also help to make this sorted by launch time. 
- */ - def checkSpeculatableTasks(): Boolean = { - // Can't speculate if we only have one task, or if all tasks have finished. - if (numTasks == 1 || tasksFinished == numTasks) { - return false - } - var foundTasks = false - val minFinishedForSpeculation = (SPECULATION_QUANTILE * numTasks).floor.toInt - logDebug("Checking for speculative tasks: minFinished = " + minFinishedForSpeculation) - if (tasksFinished >= minFinishedForSpeculation) { - val time = System.currentTimeMillis() - val durations = taskInfos.values.filter(_.successful).map(_.duration).toArray - Arrays.sort(durations) - val medianDuration = durations(min((0.5 * numTasks).round.toInt, durations.size - 1)) - val threshold = max(SPECULATION_MULTIPLIER * medianDuration, 100) - // TODO: Threshold should also look at standard deviation of task durations and have a lower - // bound based on that. - logDebug("Task length threshold for speculation: " + threshold) - for ((tid, info) <- taskInfos) { - val index = info.index - if (!finished(index) && copiesRunning(index) == 1 && info.timeRunning(time) > threshold && - !speculatableTasks.contains(index)) { - logInfo( - "Marking task %s:%d (on %s) as speculatable because it ran more than %.0f ms".format( - taskSet.id, index, info.host, threshold)) - speculatableTasks += index - foundTasks = true - } - } - } - return foundTasks - } +private[spark] trait TaskSetManager extends Schedulable { + def taskSet: TaskSet + def slaveOffer(execId: String, hostPort: String, availableCpus: Double, + overrideLocality: TaskLocality.TaskLocality = null): Option[TaskDescription] + def numPendingTasksForHostPort(hostPort: String): Int + def numRackLocalPendingTasksForHost(hostPort :String): Int + def numPendingTasksForHost(hostPort: String): Int + def statusUpdate(tid: Long, state: TaskState, serializedData: ByteBuffer) + def error(message: String) } diff --git a/core/src/main/scala/spark/scheduler/cluster/WorkerOffer.scala 
b/core/src/main/scala/spark/scheduler/cluster/WorkerOffer.scala index 3c3afcbb14..c47824315c 100644 --- a/core/src/main/scala/spark/scheduler/cluster/WorkerOffer.scala +++ b/core/src/main/scala/spark/scheduler/cluster/WorkerOffer.scala @@ -4,5 +4,5 @@ package spark.scheduler.cluster * Represents free resources available on an executor. */ private[spark] -class WorkerOffer(val executorId: String, val hostname: String, val cores: Int) { +class WorkerOffer(val executorId: String, val hostPort: String, val cores: Int) { } diff --git a/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala b/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala index 9e1bde3fbe..93d4318b29 100644 --- a/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala +++ b/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala @@ -2,19 +2,50 @@ package spark.scheduler.local import java.io.File import java.util.concurrent.atomic.AtomicInteger +import java.nio.ByteBuffer +import scala.collection.mutable.ArrayBuffer import scala.collection.mutable.HashMap +import scala.collection.mutable.HashSet import spark._ +import spark.TaskState.TaskState import spark.executor.ExecutorURLClassLoader import spark.scheduler._ -import spark.scheduler.cluster.TaskInfo +import spark.scheduler.cluster._ +import akka.actor._ /** - * A simple TaskScheduler implementation that runs tasks locally in a thread pool. Optionally + * A FIFO or Fair TaskScheduler implementation that runs tasks locally in a thread pool. Optionally * the scheduler also allows each task to fail up to maxFailures times, which is useful for * testing fault recovery. 
*/ -private[spark] class LocalScheduler(threads: Int, maxFailures: Int, sc: SparkContext) + +private[spark] case class LocalReviveOffers() +private[spark] case class LocalStatusUpdate(taskId: Long, state: TaskState, serializedData: ByteBuffer) + +private[spark] class LocalActor(localScheduler: LocalScheduler, var freeCores: Int) extends Actor with Logging { + def receive = { + case LocalReviveOffers => + launchTask(localScheduler.resourceOffer(freeCores)) + case LocalStatusUpdate(taskId, state, serializeData) => + freeCores += 1 + localScheduler.statusUpdate(taskId, state, serializeData) + launchTask(localScheduler.resourceOffer(freeCores)) + } + + def launchTask(tasks : Seq[TaskDescription]) { + for (task <- tasks) { + freeCores -= 1 + localScheduler.threadPool.submit(new Runnable { + def run() { + localScheduler.runTask(task.taskId,task.serializedTask) + } + }) + } + } +} + +private[spark] class LocalScheduler(threads: Int, val maxFailures: Int, val sc: SparkContext) extends TaskScheduler with Logging { @@ -30,87 +61,127 @@ private[spark] class LocalScheduler(threads: Int, maxFailures: Int, sc: SparkCon val classLoader = new ExecutorURLClassLoader(Array(), Thread.currentThread.getContextClassLoader) - // TODO: Need to take into account stage priority in scheduling + var schedulableBuilder: SchedulableBuilder = null + var rootPool: Pool = null + val activeTaskSets = new HashMap[String, TaskSetManager] + val taskIdToTaskSetId = new HashMap[Long, String] + val taskSetTaskIds = new HashMap[String, HashSet[Long]] + + var localActor: ActorRef = null + + override def start() { + //default scheduler is FIFO + val schedulingMode = System.getProperty("spark.cluster.schedulingmode", "FIFO") + //temporarily set rootPool name to empty + rootPool = new Pool("", SchedulingMode.withName(schedulingMode), 0, 0) + schedulableBuilder = { + schedulingMode match { + case "FIFO" => + new FIFOSchedulableBuilder(rootPool) + case "FAIR" => + new FairSchedulableBuilder(rootPool) + } + } + 
schedulableBuilder.buildPools() - override def start() { } + localActor = env.actorSystem.actorOf(Props(new LocalActor(this, threads)), "Test") + } override def setListener(listener: TaskSchedulerListener) { this.listener = listener } override def submitTasks(taskSet: TaskSet) { - val tasks = taskSet.tasks - val failCount = new Array[Int](tasks.size) - - def submitTask(task: Task[_], idInJob: Int) { - val myAttemptId = attemptId.getAndIncrement() - threadPool.submit(new Runnable { - def run() { - runTask(task, idInJob, myAttemptId) - } - }) + synchronized { + var manager = new LocalTaskSetManager(this, taskSet) + schedulableBuilder.addTaskSetManager(manager, manager.taskSet.properties) + activeTaskSets(taskSet.id) = manager + taskSetTaskIds(taskSet.id) = new HashSet[Long]() + localActor ! LocalReviveOffers } + } + + def resourceOffer(freeCores: Int): Seq[TaskDescription] = { + synchronized { + var freeCpuCores = freeCores + val tasks = new ArrayBuffer[TaskDescription](freeCores) + val sortedTaskSetQueue = rootPool.getSortedTaskSetQueue() + for (manager <- sortedTaskSetQueue) { + logDebug("parentName:%s,name:%s,runningTasks:%s".format(manager.parent.name, manager.name, manager.runningTasks)) + } - def runTask(task: Task[_], idInJob: Int, attemptId: Int) { - logInfo("Running " + task) - val info = new TaskInfo(attemptId, idInJob, System.currentTimeMillis(), "local", "local", true) - // Set the Spark execution environment for the worker thread - SparkEnv.set(env) - try { - Accumulators.clear() - Thread.currentThread().setContextClassLoader(classLoader) - - // Serialize and deserialize the task so that accumulators are changed to thread-local ones; - // this adds a bit of unnecessary overhead but matches how the Mesos Executor works. 
- val ser = SparkEnv.get.closureSerializer.newInstance() - val bytes = Task.serializeWithDependencies(task, sc.addedFiles, sc.addedJars, ser) - logInfo("Size of task " + idInJob + " is " + bytes.limit + " bytes") - val (taskFiles, taskJars, taskBytes) = Task.deserializeWithDependencies(bytes) - updateDependencies(taskFiles, taskJars) // Download any files added with addFile - val deserStart = System.currentTimeMillis() - val deserializedTask = ser.deserialize[Task[_]]( - taskBytes, Thread.currentThread.getContextClassLoader) - val deserTime = System.currentTimeMillis() - deserStart - - // Run it - val result: Any = deserializedTask.run(attemptId) - - // Serialize and deserialize the result to emulate what the Mesos - // executor does. This is useful to catch serialization errors early - // on in development (so when users move their local Spark programs - // to the cluster, they don't get surprised by serialization errors). - val serResult = ser.serialize(result) - deserializedTask.metrics.get.resultSize = serResult.limit() - val resultToReturn = ser.deserialize[Any](serResult) - val accumUpdates = ser.deserialize[collection.mutable.Map[Long, Any]]( - ser.serialize(Accumulators.values)) - logInfo("Finished " + task) - info.markSuccessful() - deserializedTask.metrics.get.executorRunTime = info.duration.toInt //close enough - deserializedTask.metrics.get.executorDeserializeTime = deserTime.toInt - - // If the threadpool has not already been shutdown, notify DAGScheduler - if (!Thread.currentThread().isInterrupted) - listener.taskEnded(task, Success, resultToReturn, accumUpdates, info, deserializedTask.metrics.getOrElse(null)) - } catch { - case t: Throwable => { - logError("Exception in task " + idInJob, t) - failCount.synchronized { - failCount(idInJob) += 1 - if (failCount(idInJob) <= maxFailures) { - submitTask(task, idInJob) - } else { - // TODO: Do something nicer here to return all the way to the user - if (!Thread.currentThread().isInterrupted) - 
listener.taskEnded(task, new ExceptionFailure(t), null, null, info, null) + var launchTask = false + for (manager <- sortedTaskSetQueue) { + do { + launchTask = false + manager.slaveOffer(null,null,freeCpuCores) match { + case Some(task) => + tasks += task + taskIdToTaskSetId(task.taskId) = manager.taskSet.id + taskSetTaskIds(manager.taskSet.id) += task.taskId + freeCpuCores -= 1 + launchTask = true + case None => {} } - } - } + } while(launchTask) } + return tasks } + } - for ((task, i) <- tasks.zipWithIndex) { - submitTask(task, i) + def taskSetFinished(manager: TaskSetManager) { + synchronized { + activeTaskSets -= manager.taskSet.id + manager.parent.removeSchedulable(manager) + logInfo("Remove TaskSet %s from pool %s".format(manager.taskSet.id, manager.parent.name)) + taskIdToTaskSetId --= taskSetTaskIds(manager.taskSet.id) + taskSetTaskIds -= manager.taskSet.id + } + } + + def runTask(taskId: Long, bytes: ByteBuffer) { + logInfo("Running " + taskId) + val info = new TaskInfo(taskId, 0, System.currentTimeMillis(), "local", "local:1", TaskLocality.NODE_LOCAL) + // Set the Spark execution environment for the worker thread + SparkEnv.set(env) + val ser = SparkEnv.get.closureSerializer.newInstance() + try { + Accumulators.clear() + Thread.currentThread().setContextClassLoader(classLoader) + + // Serialize and deserialize the task so that accumulators are changed to thread-local ones; + // this adds a bit of unnecessary overhead but matches how the Mesos Executor works. 
+ val (taskFiles, taskJars, taskBytes) = Task.deserializeWithDependencies(bytes) + updateDependencies(taskFiles, taskJars) // Download any files added with addFile + val deserStart = System.currentTimeMillis() + val deserializedTask = ser.deserialize[Task[_]]( + taskBytes, Thread.currentThread.getContextClassLoader) + val deserTime = System.currentTimeMillis() - deserStart + + // Run it + val result: Any = deserializedTask.run(taskId) + + // Serialize and deserialize the result to emulate what the Mesos + // executor does. This is useful to catch serialization errors early + // on in development (so when users move their local Spark programs + // to the cluster, they don't get surprised by serialization errors). + val serResult = ser.serialize(result) + deserializedTask.metrics.get.resultSize = serResult.limit() + val resultToReturn = ser.deserialize[Any](serResult) + val accumUpdates = ser.deserialize[collection.mutable.Map[Long, Any]]( + ser.serialize(Accumulators.values)) + logInfo("Finished " + taskId) + deserializedTask.metrics.get.executorRunTime = deserTime.toInt//info.duration.toInt //close enough + deserializedTask.metrics.get.executorDeserializeTime = deserTime.toInt + + val taskResult = new TaskResult(result, accumUpdates, deserializedTask.metrics.getOrElse(null)) + val serializedResult = ser.serialize(taskResult) + localActor ! LocalStatusUpdate(taskId, TaskState.FINISHED, serializedResult) + } catch { + case t: Throwable => { + val failure = new ExceptionFailure(t.getClass.getName, t.toString, t.getStackTrace) + localActor ! 
LocalStatusUpdate(taskId, TaskState.FAILED, ser.serialize(failure)) + } } } @@ -126,6 +197,7 @@ private[spark] class LocalScheduler(threads: Int, maxFailures: Int, sc: SparkCon Utils.fetchFile(name, new File(SparkFiles.getRootDirectory)) currentFiles(name) = timestamp } + for ((name, timestamp) <- newJars if currentJars.getOrElse(name, -1L) < timestamp) { logInfo("Fetching " + name + " with timestamp " + timestamp) Utils.fetchFile(name, new File(SparkFiles.getRootDirectory)) @@ -141,7 +213,16 @@ private[spark] class LocalScheduler(threads: Int, maxFailures: Int, sc: SparkCon } } - override def stop() { + def statusUpdate(taskId :Long, state: TaskState, serializedData: ByteBuffer) { + synchronized { + val taskSetId = taskIdToTaskSetId(taskId) + val taskSetManager = activeTaskSets(taskSetId) + taskSetTaskIds(taskSetId) -= taskId + taskSetManager.statusUpdate(taskId, state, serializedData) + } + } + + override def stop() { threadPool.shutdownNow() } diff --git a/core/src/main/scala/spark/scheduler/local/LocalTaskSetManager.scala b/core/src/main/scala/spark/scheduler/local/LocalTaskSetManager.scala new file mode 100644 index 0000000000..70b69bb26f --- /dev/null +++ b/core/src/main/scala/spark/scheduler/local/LocalTaskSetManager.scala @@ -0,0 +1,172 @@ +package spark.scheduler.local + +import java.io.File +import java.util.concurrent.atomic.AtomicInteger +import java.nio.ByteBuffer +import scala.collection.mutable.ArrayBuffer +import scala.collection.mutable.HashMap +import scala.collection.mutable.HashSet + +import spark._ +import spark.TaskState.TaskState +import spark.scheduler._ +import spark.scheduler.cluster._ + +private[spark] class LocalTaskSetManager(sched: LocalScheduler, val taskSet: TaskSet) extends TaskSetManager with Logging { + var parent: Schedulable = null + var weight: Int = 1 + var minShare: Int = 0 + var runningTasks: Int = 0 + var priority: Int = taskSet.priority + var stageId: Int = taskSet.stageId + var name: String = 
"TaskSet_"+taskSet.stageId.toString + + + var failCount = new Array[Int](taskSet.tasks.size) + val taskInfos = new HashMap[Long, TaskInfo] + val numTasks = taskSet.tasks.size + var numFinished = 0 + val ser = SparkEnv.get.closureSerializer.newInstance() + val copiesRunning = new Array[Int](numTasks) + val finished = new Array[Boolean](numTasks) + val numFailures = new Array[Int](numTasks) + val MAX_TASK_FAILURES = sched.maxFailures + + def increaseRunningTasks(taskNum: Int): Unit = { + runningTasks += taskNum + if (parent != null) { + parent.increaseRunningTasks(taskNum) + } + } + + def decreaseRunningTasks(taskNum: Int): Unit = { + runningTasks -= taskNum + if (parent != null) { + parent.decreaseRunningTasks(taskNum) + } + } + + def addSchedulable(schedulable: Schedulable): Unit = { + //nothing + } + + def removeSchedulable(schedulable: Schedulable): Unit = { + //nothing + } + + def getSchedulableByName(name: String): Schedulable = { + return null + } + + def executorLost(executorId: String, host: String): Unit = { + //nothing + } + + def checkSpeculatableTasks(): Boolean = { + return true + } + + def getSortedTaskSetQueue(): ArrayBuffer[TaskSetManager] = { + var sortedTaskSetQueue = new ArrayBuffer[TaskSetManager] + sortedTaskSetQueue += this + return sortedTaskSetQueue + } + + def hasPendingTasks(): Boolean = { + return true + } + + def findTask(): Option[Int] = { + for (i <- 0 to numTasks-1) { + if (copiesRunning(i) == 0 && !finished(i)) { + return Some(i) + } + } + return None + } + + def slaveOffer(execId: String, hostPort: String, availableCpus: Double, overrideLocality: TaskLocality.TaskLocality = null): Option[TaskDescription] = { + SparkEnv.set(sched.env) + logDebug("availableCpus:%d,numFinished:%d,numTasks:%d".format(availableCpus.toInt, numFinished, numTasks)) + if (availableCpus > 0 && numFinished < numTasks) { + findTask() match { + case Some(index) => + val taskId = sched.attemptId.getAndIncrement() + val task = taskSet.tasks(index) + val info = new 
TaskInfo(taskId, index, System.currentTimeMillis(), "local", "local:1", TaskLocality.NODE_LOCAL) + taskInfos(taskId) = info + val bytes = Task.serializeWithDependencies(task, sched.sc.addedFiles, sched.sc.addedJars, ser) + logInfo("Size of task " + taskId + " is " + bytes.limit + " bytes") + val taskName = "task %s:%d".format(taskSet.id, index) + copiesRunning(index) += 1 + increaseRunningTasks(1) + return Some(new TaskDescription(taskId, null, taskName, bytes)) + case None => {} + } + } + return None + } + + def numPendingTasksForHostPort(hostPort: String): Int = { + return 0 + } + + def numRackLocalPendingTasksForHost(hostPort :String): Int = { + return 0 + } + + def numPendingTasksForHost(hostPort: String): Int = { + return 0 + } + + def statusUpdate(tid: Long, state: TaskState, serializedData: ByteBuffer) { + state match { + case TaskState.FINISHED => + taskEnded(tid, state, serializedData) + case TaskState.FAILED => + taskFailed(tid, state, serializedData) + case _ => {} + } + } + + def taskEnded(tid: Long, state: TaskState, serializedData: ByteBuffer) { + val info = taskInfos(tid) + val index = info.index + val task = taskSet.tasks(index) + info.markSuccessful() + val result = ser.deserialize[TaskResult[_]](serializedData, getClass.getClassLoader) + result.metrics.resultSize = serializedData.limit() + sched.listener.taskEnded(task, Success, result.value, result.accumUpdates, info, result.metrics) + numFinished += 1 + decreaseRunningTasks(1) + finished(index) = true + if (numFinished == numTasks) { + sched.taskSetFinished(this) + } + } + + def taskFailed(tid: Long, state: TaskState, serializedData: ByteBuffer) { + val info = taskInfos(tid) + val index = info.index + val task = taskSet.tasks(index) + info.markFailed() + decreaseRunningTasks(1) + val reason: ExceptionFailure = ser.deserialize[ExceptionFailure](serializedData, getClass.getClassLoader) + if (!finished(index)) { + copiesRunning(index) -= 1 + numFailures(index) += 1 + val locs = 
reason.stackTrace.map(loc => "\tat %s".format(loc.toString)) + logInfo("Loss was due to %s\n%s\n%s".format(reason.className, reason.description, locs.mkString("\n"))) + if (numFailures(index) > MAX_TASK_FAILURES) { + val errorMessage = "Task %s:%d failed more than %d times; aborting job %s".format(taskSet.id, index, 4, reason.description) + decreaseRunningTasks(runningTasks) + sched.listener.taskSetFailed(taskSet, errorMessage) + // need to delete failed Taskset from schedule queue + sched.taskSetFinished(this) + } + } + } + + def error(message: String) { + } +} diff --git a/core/src/main/scala/spark/serializer/Serializer.scala b/core/src/main/scala/spark/serializer/Serializer.scala index aca86ab6f0..2ad73b711d 100644 --- a/core/src/main/scala/spark/serializer/Serializer.scala +++ b/core/src/main/scala/spark/serializer/Serializer.scala @@ -1,10 +1,13 @@ package spark.serializer -import java.nio.ByteBuffer import java.io.{EOFException, InputStream, OutputStream} +import java.nio.ByteBuffer + import it.unimi.dsi.fastutil.io.FastByteArrayOutputStream + import spark.util.ByteBufferInputStream + /** * A serializer. Because some serialization libraries are not thread safe, this class is used to * create [[spark.serializer.SerializerInstance]] objects that do the actual serialization and are @@ -14,6 +17,7 @@ trait Serializer { def newInstance(): SerializerInstance } + /** * An instance of a serializer, for use by one thread at a time. */ @@ -45,6 +49,7 @@ trait SerializerInstance { } } + /** * A stream for writing serialized objects. */ @@ -61,6 +66,7 @@ trait SerializationStream { } } + /** * A stream for reading serialized objects. 
*/ diff --git a/core/src/main/scala/spark/serializer/SerializerManager.scala b/core/src/main/scala/spark/serializer/SerializerManager.scala new file mode 100644 index 0000000000..60b2aac797 --- /dev/null +++ b/core/src/main/scala/spark/serializer/SerializerManager.scala @@ -0,0 +1,45 @@ +package spark.serializer + +import java.util.concurrent.ConcurrentHashMap + + +/** + * A service that returns a serializer object given the serializer's class name. If a previous + * instance of the serializer object has been created, the get method returns that instead of + * creating a new one. + */ +private[spark] class SerializerManager { + + private val serializers = new ConcurrentHashMap[String, Serializer] + private var _default: Serializer = _ + + def default = _default + + def setDefault(clsName: String): Serializer = { + _default = get(clsName) + _default + } + + def get(clsName: String): Serializer = { + if (clsName == null) { + default + } else { + var serializer = serializers.get(clsName) + if (serializer != null) { + // If the serializer has been created previously, reuse that. + serializer + } else this.synchronized { + // Otherwise, create a new one. But make sure no other thread has attempted + // to create another new one at the same time. 
+ serializer = serializers.get(clsName) + if (serializer == null) { + val clsLoader = Thread.currentThread.getContextClassLoader + serializer = + Class.forName(clsName, true, clsLoader).newInstance().asInstanceOf[Serializer] + serializers.put(clsName, serializer) + } + serializer + } + } + } +} diff --git a/core/src/main/scala/spark/storage/BlockException.scala b/core/src/main/scala/spark/storage/BlockException.scala new file mode 100644 index 0000000000..f275d476df --- /dev/null +++ b/core/src/main/scala/spark/storage/BlockException.scala @@ -0,0 +1,5 @@ +package spark.storage + +private[spark] +case class BlockException(blockId: String, message: String) extends Exception(message) + diff --git a/core/src/main/scala/spark/storage/BlockFetchTracker.scala b/core/src/main/scala/spark/storage/BlockFetchTracker.scala index 993aece1f7..0718156b1b 100644 --- a/core/src/main/scala/spark/storage/BlockFetchTracker.scala +++ b/core/src/main/scala/spark/storage/BlockFetchTracker.scala @@ -1,10 +1,10 @@ package spark.storage private[spark] trait BlockFetchTracker { - def totalBlocks : Int - def numLocalBlocks: Int - def numRemoteBlocks: Int - def remoteFetchTime : Long - def fetchWaitTime: Long - def remoteBytesRead : Long + def totalBlocks : Int + def numLocalBlocks: Int + def numRemoteBlocks: Int + def remoteFetchTime : Long + def fetchWaitTime: Long + def remoteBytesRead : Long } diff --git a/core/src/main/scala/spark/storage/BlockFetcherIterator.scala b/core/src/main/scala/spark/storage/BlockFetcherIterator.scala new file mode 100644 index 0000000000..bec876213e --- /dev/null +++ b/core/src/main/scala/spark/storage/BlockFetcherIterator.scala @@ -0,0 +1,330 @@ +package spark.storage + +import java.nio.ByteBuffer +import java.util.concurrent.LinkedBlockingQueue + +import scala.collection.mutable.ArrayBuffer +import scala.collection.mutable.HashSet +import scala.collection.mutable.Queue + +import io.netty.buffer.ByteBuf + +import spark.Logging +import spark.Utils +import 
spark.SparkException +import spark.network.BufferMessage +import spark.network.ConnectionManagerId +import spark.network.netty.ShuffleCopier +import spark.serializer.Serializer + + +/** + * A block fetcher iterator interface. There are two implementations: + * + * BasicBlockFetcherIterator: uses a custom-built NIO communication layer. + * NettyBlockFetcherIterator: uses Netty (OIO) as the communication layer. + * + * Eventually we would like the two to converge and use a single NIO-based communication layer, + * but extensive tests show that under some circumstances (e.g. large shuffles with lots of cores), + * NIO would perform poorly and thus the need for the Netty OIO one. + */ + +private[storage] +trait BlockFetcherIterator extends Iterator[(String, Option[Iterator[Any]])] + with Logging with BlockFetchTracker { + def initialize() +} + + +private[storage] +object BlockFetcherIterator { + + // A request to fetch one or more blocks, complete with their sizes + class FetchRequest(val address: BlockManagerId, val blocks: Seq[(String, Long)]) { + val size = blocks.map(_._2).sum + } + + // A result of a fetch. Includes the block ID, size in bytes, and a function to deserialize + // the block (since we want all deserializaton to happen in the calling thread); can also + // represent a fetch failure if size == -1. + class FetchResult(val blockId: String, val size: Long, val deserialize: () => Iterator[Any]) { + def failed: Boolean = size == -1 + } + + class BasicBlockFetcherIterator( + private val blockManager: BlockManager, + val blocksByAddress: Seq[(BlockManagerId, Seq[(String, Long)])], + serializer: Serializer) + extends BlockFetcherIterator { + + import blockManager._ + + private var _remoteBytesRead = 0l + private var _remoteFetchTime = 0l + private var _fetchWaitTime = 0l + + if (blocksByAddress == null) { + throw new IllegalArgumentException("BlocksByAddress is null") + } + + // Total number blocks fetched (local + remote). 
Also number of FetchResults expected + protected var _numBlocksToFetch = 0 + + protected var startTime = System.currentTimeMillis + + // This represents the number of local blocks, also counting zero-sized blocks + private var numLocal = 0 + // BlockIds for local blocks that need to be fetched. Excludes zero-sized blocks + protected val localBlocksToFetch = new ArrayBuffer[String]() + + // This represents the number of remote blocks, also counting zero-sized blocks + private var numRemote = 0 + // BlockIds for remote blocks that need to be fetched. Excludes zero-sized blocks + protected val remoteBlocksToFetch = new HashSet[String]() + + // A queue to hold our results. + protected val results = new LinkedBlockingQueue[FetchResult] + + // Queue of fetch requests to issue; we'll pull requests off this gradually to make sure that + // the number of bytes in flight is limited to maxBytesInFlight + private val fetchRequests = new Queue[FetchRequest] + + // Current bytes in flight from our requests + private var bytesInFlight = 0L + + protected def sendRequest(req: FetchRequest) { + logDebug("Sending request for %d blocks (%s) from %s".format( + req.blocks.size, Utils.memoryBytesToString(req.size), req.address.hostPort)) + val cmId = new ConnectionManagerId(req.address.host, req.address.port) + val blockMessageArray = new BlockMessageArray(req.blocks.map { + case (blockId, size) => BlockMessage.fromGetBlock(GetBlock(blockId)) + }) + bytesInFlight += req.size + val sizeMap = req.blocks.toMap // so we can look up the size of each blockID + val fetchStart = System.currentTimeMillis() + val future = connectionManager.sendMessageReliably(cmId, blockMessageArray.toBufferMessage) + future.onSuccess { + case Some(message) => { + val fetchDone = System.currentTimeMillis() + _remoteFetchTime += fetchDone - fetchStart + val bufferMessage = message.asInstanceOf[BufferMessage] + val blockMessageArray = BlockMessageArray.fromBufferMessage(bufferMessage) + for (blockMessage <- 
blockMessageArray) { + if (blockMessage.getType != BlockMessage.TYPE_GOT_BLOCK) { + throw new SparkException( + "Unexpected message " + blockMessage.getType + " received from " + cmId) + } + val blockId = blockMessage.getId + results.put(new FetchResult(blockId, sizeMap(blockId), + () => dataDeserialize(blockId, blockMessage.getData, serializer))) + _remoteBytesRead += req.size + logDebug("Got remote block " + blockId + " after " + Utils.getUsedTimeMs(startTime)) + } + } + case None => { + logError("Could not get block(s) from " + cmId) + for ((blockId, size) <- req.blocks) { + results.put(new FetchResult(blockId, -1, null)) + } + } + } + } + + protected def splitLocalRemoteBlocks(): ArrayBuffer[FetchRequest] = { + // Split local and remote blocks. Remote blocks are further split into FetchRequests of size + // at most maxBytesInFlight in order to limit the amount of data in flight. + val remoteRequests = new ArrayBuffer[FetchRequest] + for ((address, blockInfos) <- blocksByAddress) { + if (address == blockManagerId) { + numLocal = blockInfos.size + // Filter out zero-sized blocks + localBlocksToFetch ++= blockInfos.filter(_._2 != 0).map(_._1) + _numBlocksToFetch += localBlocksToFetch.size + } else { + numRemote += blockInfos.size + // Make our requests at least maxBytesInFlight / 5 in length; the reason to keep them + // smaller than maxBytesInFlight is to allow multiple, parallel fetches from up to 5 + // nodes, rather than blocking on reading output from one node. 
+ val minRequestSize = math.max(maxBytesInFlight / 5, 1L) + logInfo("maxBytesInFlight: " + maxBytesInFlight + ", minRequest: " + minRequestSize) + val iterator = blockInfos.iterator + var curRequestSize = 0L + var curBlocks = new ArrayBuffer[(String, Long)] + while (iterator.hasNext) { + val (blockId, size) = iterator.next() + // Skip empty blocks + if (size > 0) { + curBlocks += ((blockId, size)) + remoteBlocksToFetch += blockId + _numBlocksToFetch += 1 + curRequestSize += size + } else if (size < 0) { + throw new BlockException(blockId, "Negative block size " + size) + } + if (curRequestSize >= minRequestSize) { + // Add this FetchRequest + remoteRequests += new FetchRequest(address, curBlocks) + curRequestSize = 0 + curBlocks = new ArrayBuffer[(String, Long)] + } + } + // Add in the final request + if (!curBlocks.isEmpty) { + remoteRequests += new FetchRequest(address, curBlocks) + } + } + } + logInfo("Getting " + _numBlocksToFetch + " non-zero-bytes blocks out of " + + totalBlocks + " blocks") + remoteRequests + } + + protected def getLocalBlocks() { + // Get the local blocks while remote blocks are being fetched. Note that it's okay to do + // these all at once because they will just memory-map some files, so they won't consume + // any memory that might exceed our maxBytesInFlight + for (id <- localBlocksToFetch) { + getLocalFromDisk(id, serializer) match { + case Some(iter) => { + // Pass 0 as size since it's not in flight + results.put(new FetchResult(id, 0, () => iter)) + logDebug("Got local block " + id) + } + case None => { + throw new BlockException(id, "Could not get block " + id + " from local machine") + } + } + } + } + + override def initialize() { + // Split local and remote blocks. 
+ val remoteRequests = splitLocalRemoteBlocks() + // Add the remote requests into our queue in a random order + fetchRequests ++= Utils.randomize(remoteRequests) + + // Send out initial requests for blocks, up to our maxBytesInFlight + while (!fetchRequests.isEmpty && + (bytesInFlight == 0 || bytesInFlight + fetchRequests.front.size <= maxBytesInFlight)) { + sendRequest(fetchRequests.dequeue()) + } + + val numGets = remoteRequests.size - fetchRequests.size + logInfo("Started " + numGets + " remote gets in " + Utils.getUsedTimeMs(startTime)) + + // Get Local Blocks + startTime = System.currentTimeMillis + getLocalBlocks() + logDebug("Got local blocks in " + Utils.getUsedTimeMs(startTime) + " ms") + } + + //an iterator that will read fetched blocks off the queue as they arrive. + @volatile protected var resultsGotten = 0 + + override def hasNext: Boolean = resultsGotten < _numBlocksToFetch + + override def next(): (String, Option[Iterator[Any]]) = { + resultsGotten += 1 + val startFetchWait = System.currentTimeMillis() + val result = results.take() + val stopFetchWait = System.currentTimeMillis() + _fetchWaitTime += (stopFetchWait - startFetchWait) + if (! result.failed) bytesInFlight -= result.size + while (!fetchRequests.isEmpty && + (bytesInFlight == 0 || bytesInFlight + fetchRequests.front.size <= maxBytesInFlight)) { + sendRequest(fetchRequests.dequeue()) + } + (result.blockId, if (result.failed) None else Some(result.deserialize())) + } + + // Implementing BlockFetchTracker trait. 
+ override def totalBlocks: Int = numLocal + numRemote + override def numLocalBlocks: Int = numLocal + override def numRemoteBlocks: Int = numRemote + override def remoteFetchTime: Long = _remoteFetchTime + override def fetchWaitTime: Long = _fetchWaitTime + override def remoteBytesRead: Long = _remoteBytesRead + } + // End of BasicBlockFetcherIterator + + class NettyBlockFetcherIterator( + blockManager: BlockManager, + blocksByAddress: Seq[(BlockManagerId, Seq[(String, Long)])], + serializer: Serializer) + extends BasicBlockFetcherIterator(blockManager, blocksByAddress, serializer) { + + import blockManager._ + + val fetchRequestsSync = new LinkedBlockingQueue[FetchRequest] + + private def startCopiers(numCopiers: Int): List[_ <: Thread] = { + (for ( i <- Range(0,numCopiers) ) yield { + val copier = new Thread { + override def run(){ + try { + while(!isInterrupted && !fetchRequestsSync.isEmpty) { + sendRequest(fetchRequestsSync.take()) + } + } catch { + case x: InterruptedException => logInfo("Copier Interrupted") + //case _ => throw new SparkException("Exception Throw in Shuffle Copier") + } + } + } + copier.start + copier + }).toList + } + + // keep this to interrupt the threads when necessary + private def stopCopiers() { + for (copier <- copiers) { + copier.interrupt() + } + } + + override protected def sendRequest(req: FetchRequest) { + + def putResult(blockId: String, blockSize: Long, blockData: ByteBuf) { + val fetchResult = new FetchResult(blockId, blockSize, + () => dataDeserialize(blockId, blockData.nioBuffer, serializer)) + results.put(fetchResult) + } + + logDebug("Sending request for %d blocks (%s) from %s".format( + req.blocks.size, Utils.memoryBytesToString(req.size), req.address.host)) + val cmId = new ConnectionManagerId(req.address.host, req.address.nettyPort) + val cpier = new ShuffleCopier + cpier.getBlocks(cmId, req.blocks, putResult) + logDebug("Sent request for remote blocks " + req.blocks + " from " + req.address.host ) + } + + private var 
copiers: List[_ <: Thread] = null + + override def initialize() { + // Split Local Remote Blocks and set numBlocksToFetch + val remoteRequests = splitLocalRemoteBlocks() + // Add the remote requests into our queue in a random order + for (request <- Utils.randomize(remoteRequests)) { + fetchRequestsSync.put(request) + } + + copiers = startCopiers(System.getProperty("spark.shuffle.copier.threads", "6").toInt) + logInfo("Started " + fetchRequestsSync.size + " remote gets in " + + Utils.getUsedTimeMs(startTime)) + + // Get Local Blocks + startTime = System.currentTimeMillis + getLocalBlocks() + logDebug("Got local blocks in " + Utils.getUsedTimeMs(startTime) + " ms") + } + + override def next(): (String, Option[Iterator[Any]]) = { + resultsGotten += 1 + val result = results.take() + // If all the results has been retrieved, copiers will exit automatically + (result.blockId, if (result.failed) None else Some(result.deserialize())) + } + } + // End of NettyBlockFetcherIterator +} diff --git a/core/src/main/scala/spark/storage/BlockManager.scala b/core/src/main/scala/spark/storage/BlockManager.scala index d3f6cd78dc..4bb4927b4a 100644 --- a/core/src/main/scala/spark/storage/BlockManager.scala +++ b/core/src/main/scala/spark/storage/BlockManager.scala @@ -2,10 +2,8 @@ package spark.storage import java.io.{InputStream, OutputStream} import java.nio.{ByteBuffer, MappedByteBuffer} -import java.util.concurrent.{ConcurrentHashMap, LinkedBlockingQueue} -import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, Queue} -import scala.collection.JavaConversions._ +import scala.collection.mutable.{HashMap, ArrayBuffer, HashSet} import akka.actor.{ActorSystem, Cancellable, Props} import scala.concurrent.{Await, Future} @@ -16,7 +14,7 @@ import com.ning.compress.lzf.{LZFInputStream, LZFOutputStream} import it.unimi.dsi.fastutil.io.FastByteArrayOutputStream -import spark.{Logging, SizeEstimator, SparkEnv, SparkException, Utils} +import spark.{Logging, SparkEnv, SparkException, 
Utils} import spark.network._ import spark.serializer.Serializer import spark.util.{ByteBufferInputStream, IdGenerator, MetadataCleaner, TimeStampedHashMap} @@ -24,30 +22,35 @@ import spark.util.{ByteBufferInputStream, IdGenerator, MetadataCleaner, TimeStam import sun.nio.ch.DirectBuffer -private[spark] -case class BlockException(blockId: String, message: String, ex: Exception = null) -extends Exception(message) - -private[spark] -class BlockManager( +private[spark] class BlockManager( executorId: String, actorSystem: ActorSystem, val master: BlockManagerMaster, - val serializer: Serializer, + val defaultSerializer: Serializer, maxMemory: Long) extends Logging { - class BlockInfo(val level: StorageLevel, val tellMaster: Boolean) { - var pending: Boolean = true - var size: Long = -1L - var failed: Boolean = false + private class BlockInfo(val level: StorageLevel, val tellMaster: Boolean) { + @volatile var pending: Boolean = true + @volatile var size: Long = -1L + @volatile var initThread: Thread = null + @volatile var failed = false + + setInitThread() + + private def setInitThread() { + // Set current thread as init thread - waitForReady will not block this thread + // (in case there is non trivial initialization which ends up calling waitForReady as part of + // initialization itself) + this.initThread = Thread.currentThread() + } /** * Wait for this BlockInfo to be marked as ready (i.e. block is finished writing). * Return true if the block is available, false otherwise. */ def waitForReady(): Boolean = { - if (pending) { + if (initThread != Thread.currentThread() && pending) { synchronized { while (pending) this.wait() } @@ -57,35 +60,51 @@ class BlockManager( /** Mark this BlockInfo as ready (i.e. 
block is finished writing) */ def markReady(sizeInBytes: Long) { + assert (pending) + size = sizeInBytes + initThread = null + failed = false + initThread = null + pending = false synchronized { - pending = false - failed = false - size = sizeInBytes this.notifyAll() } } /** Mark this BlockInfo as ready but failed */ def markFailure() { + assert (pending) + size = 0 + initThread = null + failed = true + initThread = null + pending = false synchronized { - failed = true - pending = false this.notifyAll() } } } + val shuffleBlockManager = new ShuffleBlockManager(this) + private val blockInfo = new TimeStampedHashMap[String, BlockInfo] private[storage] val memoryStore: BlockStore = new MemoryStore(this, maxMemory) - private[storage] val diskStore: BlockStore = + private[storage] val diskStore: DiskStore = new DiskStore(this, System.getProperty("spark.local.dir", System.getProperty("java.io.tmpdir"))) + // If we use Netty for shuffle, start a new Netty-based shuffle sender service. + private val nettyPort: Int = { + val useNetty = System.getProperty("spark.shuffle.use.netty", "false").toBoolean + val nettyPortConfig = System.getProperty("spark.shuffle.sender.port", "0").toInt + if (useNetty) diskStore.startShuffleBlockSender(nettyPortConfig) else 0 + } + val connectionManager = new ConnectionManager(0) implicit val futureExecContext = connectionManager.futureExecContext val blockManagerId = BlockManagerId( - executorId, connectionManager.id.host, connectionManager.id.port) + executorId, connectionManager.id.host, connectionManager.id.port, nettyPort) // Max megabytes of data to keep in flight per reducer (to avoid over-allocating memory // for receiving shuffle outputs) @@ -101,7 +120,7 @@ class BlockManager( val heartBeatFrequency = BlockManager.getHeartBeatFrequencyFromSystemProperties - val host = System.getProperty("spark.hostname", Utils.localHostName()) + val hostPort = Utils.localHostPort() val slaveActor = actorSystem.actorOf(Props(new 
BlockManagerSlaveActor(this)), name = "BlockManagerActor" + BlockManager.ID_GENERATOR.next) @@ -212,9 +231,12 @@ class BlockManager( * Tell the master about the current storage status of a block. This will send a block update * message reflecting the current status, *not* the desired storage level in its block info. * For example, a block with MEMORY_AND_DISK set might have fallen out to be only on disk. + * + * droppedMemorySize exists to account for when block is dropped from memory to disk (so it is still valid). + * This ensures that update in master will compensate for the increase in memory on slave. */ - def reportBlockStatus(blockId: String, info: BlockInfo) { - val needReregister = !tryToReportBlockStatus(blockId, info) + def reportBlockStatus(blockId: String, info: BlockInfo, droppedMemorySize: Long = 0L) { + val needReregister = !tryToReportBlockStatus(blockId, info, droppedMemorySize) if (needReregister) { logInfo("Got told to reregister updating block " + blockId) // Reregistering will report our new block for free. @@ -228,7 +250,7 @@ class BlockManager( * which will be true if the block was successfully recorded and false if * the slave needs to re-register. 
*/ - private def tryToReportBlockStatus(blockId: String, info: BlockInfo): Boolean = { + private def tryToReportBlockStatus(blockId: String, info: BlockInfo, droppedMemorySize: Long = 0L): Boolean = { val (curLevel, inMemSize, onDiskSize, tellMaster) = info.synchronized { info.level match { case null => @@ -237,7 +259,7 @@ class BlockManager( val inMem = level.useMemory && memoryStore.contains(blockId) val onDisk = level.useDisk && diskStore.contains(blockId) val storageLevel = StorageLevel(onDisk, inMem, level.deserialized, level.replication) - val memSize = if (inMem) memoryStore.getSize(blockId) else 0L + val memSize = if (inMem) memoryStore.getSize(blockId) else droppedMemorySize val diskSize = if (onDisk) diskStore.getSize(blockId) else 0L (storageLevel, memSize, diskSize, info.tellMaster) } @@ -250,26 +272,24 @@ class BlockManager( } } - /** - * Get locations of the block. + * Get locations of an array of blocks. */ - def getLocations(blockId: String): Seq[String] = { + def getLocationBlockIds(blockIds: Array[String]): Array[Seq[BlockManagerId]] = { val startTimeMs = System.currentTimeMillis - var managers = master.getLocations(blockId) - val locations = managers.map(_.ip) - logDebug("Got block locations in " + Utils.getUsedTimeMs(startTimeMs)) - return locations + val locations = master.getLocations(blockIds).toArray + logDebug("Got multiple block location in " + Utils.getUsedTimeMs(startTimeMs)) + locations } /** - * Get locations of an array of blocks. + * A short-circuited method to get blocks directly from disk. This is used for getting + * shuffle blocks. It is safe to do so without a lock on block info since disk store + * never deletes (recent) items. 
*/ - def getLocations(blockIds: Array[String]): Array[Seq[String]] = { - val startTimeMs = System.currentTimeMillis - val locations = master.getLocations(blockIds).map(_.map(_.ip).toSeq).toArray - logDebug("Got multiple block location in " + Utils.getUsedTimeMs(startTimeMs)) - return locations + def getLocalFromDisk(blockId: String, serializer: Serializer): Option[Iterator[Any]] = { + diskStore.getValues(blockId, serializer).orElse( + sys.error("Block " + blockId + " not found on disk, though it should be")) } /** @@ -277,18 +297,6 @@ class BlockManager( */ def getLocal(blockId: String): Option[Iterator[Any]] = { logDebug("Getting local block " + blockId) - - // As an optimization for map output fetches, if the block is for a shuffle, return it - // without acquiring a lock; the disk store never deletes (recent) items so this should work - if (blockId.startsWith("shuffle_")) { - return diskStore.getValues(blockId) match { - case Some(iterator) => - Some(iterator) - case None => - throw new Exception("Block " + blockId + " not found on disk, though it should be") - } - } - val info = blockInfo.get(blockId).orNull if (info != null) { info.synchronized { @@ -339,6 +347,8 @@ class BlockManager( case Some(bytes) => // Put a copy of the block back in memory before returning it. Note that we can't // put the ByteBuffer returned by the disk store as that's a memory-mapped file. + // The use of rewind assumes this. 
+ assert (0 == bytes.position()) val copyForMemory = ByteBuffer.allocate(bytes.limit) copyForMemory.put(bytes) memoryStore.putBytes(blockId, copyForMemory, level) @@ -372,7 +382,7 @@ class BlockManager( // As an optimization for map output fetches, if the block is for a shuffle, return it // without acquiring a lock; the disk store never deletes (recent) items so this should work - if (blockId.startsWith("shuffle_")) { + if (ShuffleBlockManager.isShuffle(blockId)) { return diskStore.getBytes(blockId) match { case Some(bytes) => Some(bytes) @@ -411,6 +421,7 @@ class BlockManager( // Read it as a byte buffer into memory first, then return it diskStore.getBytes(blockId) match { case Some(bytes) => + assert (0 == bytes.position()) if (level.useMemory) { if (level.deserialized) { memoryStore.putBytes(blockId, bytes, level) @@ -450,7 +461,7 @@ class BlockManager( for (loc <- locations) { logDebug("Getting remote block " + blockId + " from " + loc) val data = BlockManagerWorker.syncGetBlock( - GetBlock(blockId), ConnectionManagerId(loc.ip, loc.port)) + GetBlock(blockId), ConnectionManagerId(loc.host, loc.port)) if (data != null) { return Some(dataDeserialize(blockId, data)) } @@ -473,9 +484,19 @@ class BlockManager( * fashion as they're received. Expects a size in bytes to be provided for each block fetched, * so that we can control the maxMegabytesInFlight for the fetch. 
*/ - def getMultiple(blocksByAddress: Seq[(BlockManagerId, Seq[(String, Long)])]) + def getMultiple( + blocksByAddress: Seq[(BlockManagerId, Seq[(String, Long)])], serializer: Serializer) : BlockFetcherIterator = { - return new BlockFetcherIterator(this, blocksByAddress) + + val iter = + if (System.getProperty("spark.shuffle.use.netty", "false").toBoolean) { + new BlockFetcherIterator.NettyBlockFetcherIterator(this, blocksByAddress, serializer) + } else { + new BlockFetcherIterator.BasicBlockFetcherIterator(this, blocksByAddress, serializer) + } + + iter.initialize() + iter } def put(blockId: String, values: Iterator[Any], level: StorageLevel, tellMaster: Boolean) @@ -486,6 +507,22 @@ class BlockManager( } /** + * A short circuited method to get a block writer that can write data directly to disk. + * This is currently used for writing shuffle files out. Callers should handle error + * cases. + */ + def getDiskBlockWriter(blockId: String, serializer: Serializer, bufferSize: Int) + : BlockObjectWriter = { + val writer = diskStore.getBlockWriter(blockId, serializer, bufferSize) + writer.registerCloseEventHandler(() => { + val myInfo = new BlockInfo(StorageLevel.DISK_ONLY, false) + blockInfo.put(blockId, myInfo) + myInfo.markReady(writer.size()) + }) + writer + } + + /** * Put a new block of values to the block manager. Returns its (estimated) size in bytes. */ def put(blockId: String, values: ArrayBuffer[Any], level: StorageLevel, @@ -501,17 +538,26 @@ class BlockManager( throw new IllegalArgumentException("Storage level is null or invalid") } - val oldBlock = blockInfo.get(blockId).orNull - if (oldBlock != null && oldBlock.waitForReady()) { - logWarning("Block " + blockId + " already exists on this machine; not re-adding it") - return oldBlock.size - } - // Remember the block's storage level so that we can correctly drop it to disk if it needs // to be dropped right after it got put into memory. 
Note, however, that other threads will // not be able to get() this block until we call markReady on its BlockInfo. - val myInfo = new BlockInfo(level, tellMaster) - blockInfo.put(blockId, myInfo) + val myInfo = { + val tinfo = new BlockInfo(level, tellMaster) + // Do atomically ! + val oldBlockOpt = blockInfo.putIfAbsent(blockId, tinfo) + + if (oldBlockOpt.isDefined) { + if (oldBlockOpt.get.waitForReady()) { + logWarning("Block " + blockId + " already exists on this machine; not re-adding it") + return oldBlockOpt.get.size + } + + // TODO: So the block info exists - but previous attempt to load it (?) failed. What do we do now ? Retry on it ? + oldBlockOpt.get + } else { + tinfo + } + } val startTimeMs = System.currentTimeMillis @@ -531,6 +577,7 @@ class BlockManager( logTrace("Put for block " + blockId + " took " + Utils.getUsedTimeMs(startTimeMs) + " to get into synchronized block") + var marked = false try { if (level.useMemory) { // Save it just to memory first, even if it also has useDisk set to true; we will later @@ -555,26 +602,25 @@ class BlockManager( // Now that the block is in either the memory or disk store, let other threads read it, // and tell the master about it. + marked = true myInfo.markReady(size) if (tellMaster) { reportBlockStatus(blockId, myInfo) } - } catch { + } finally { // If we failed at putting the block to memory/disk, notify other possible readers // that it has failed, and then remove it from the block info map. - case e: Exception => { + if (! marked) { // Note that the remove must happen before markFailure otherwise another thread // could've inserted a new BlockInfo before we remove it. 
blockInfo.remove(blockId) myInfo.markFailure() - logWarning("Putting block " + blockId + " failed", e) - throw e + logWarning("Putting block " + blockId + " failed") } } } logDebug("Put block " + blockId + " locally took " + Utils.getUsedTimeMs(startTimeMs)) - // Replicate block if required if (level.replication > 1) { val remoteStartTime = System.currentTimeMillis @@ -611,16 +657,26 @@ class BlockManager( throw new IllegalArgumentException("Storage level is null or invalid") } - if (blockInfo.contains(blockId)) { - logWarning("Block " + blockId + " already exists on this machine; not re-adding it") - return - } - // Remember the block's storage level so that we can correctly drop it to disk if it needs // to be dropped right after it got put into memory. Note, however, that other threads will // not be able to get() this block until we call markReady on its BlockInfo. - val myInfo = new BlockInfo(level, tellMaster) - blockInfo.put(blockId, myInfo) + val myInfo = { + val tinfo = new BlockInfo(level, tellMaster) + // Do atomically ! + val oldBlockOpt = blockInfo.putIfAbsent(blockId, tinfo) + + if (oldBlockOpt.isDefined) { + if (oldBlockOpt.get.waitForReady()) { + logWarning("Block " + blockId + " already exists on this machine; not re-adding it") + return + } + + // TODO: So the block info exists - but previous attempt to load it (?) failed. What do we do now ? Retry on it ? 
+ oldBlockOpt.get + } else { + tinfo + } + } val startTimeMs = System.currentTimeMillis @@ -639,6 +695,7 @@ class BlockManager( logDebug("PutBytes for block " + blockId + " took " + Utils.getUsedTimeMs(startTimeMs) + " to get into synchronized block") + var marked = false try { if (level.useMemory) { // Store it only in memory at first, even if useDisk is also set to true @@ -649,22 +706,24 @@ class BlockManager( diskStore.putBytes(blockId, bytes, level) } + // assert (0 == bytes.position(), "" + bytes) + // Now that the block is in either the memory or disk store, let other threads read it, // and tell the master about it. + marked = true myInfo.markReady(bytes.limit) if (tellMaster) { reportBlockStatus(blockId, myInfo) } - } catch { + } finally { // If we failed at putting the block to memory/disk, notify other possible readers // that it has failed, and then remove it from the block info map. - case e: Exception => { + if (! marked) { // Note that the remove must happen before markFailure otherwise another thread // could've inserted a new BlockInfo before we remove it. blockInfo.remove(blockId) myInfo.markFailure() - logWarning("Putting block " + blockId + " failed", e) - throw e + logWarning("Putting block " + blockId + " failed") } } } @@ -698,7 +757,7 @@ class BlockManager( logDebug("Try to replicate BlockId " + blockId + " once; The size of the data is " + data.limit() + " Bytes. To node: " + peer) if (!BlockManagerWorker.syncPutBlock(PutBlock(blockId, data, tLevel), - new ConnectionManagerId(peer.ip, peer.port))) { + new ConnectionManagerId(peer.host, peer.port))) { logError("Failed to call syncPutBlock to " + peer) } logDebug("Replicated BlockId " + blockId + " once used " + @@ -730,6 +789,14 @@ class BlockManager( val info = blockInfo.get(blockId).orNull if (info != null) { info.synchronized { + // required ? As of now, this will be invoked only for blocks which are ready + // But in case this changes in future, adding for consistency sake. + if (! 
info.waitForReady() ) { + // If we get here, the block write failed. + logWarning("Block " + blockId + " was marked as failure. Nothing to drop") + return + } + val level = info.level if (level.useDisk && !diskStore.contains(blockId)) { logInfo("Writing block " + blockId + " to disk") @@ -740,12 +807,13 @@ class BlockManager( diskStore.putBytes(blockId, bytes, level) } } + val droppedMemorySize = if (memoryStore.contains(blockId)) memoryStore.getSize(blockId) else 0L val blockWasRemoved = memoryStore.remove(blockId) if (!blockWasRemoved) { logWarning("Block " + blockId + " could not be dropped from memory as it does not exist") } if (info.tellMaster) { - reportBlockStatus(blockId, info) + reportBlockStatus(blockId, info, droppedMemorySize) } if (!level.useDisk) { // The block is completely gone from this node; forget it so we can put() it again later. @@ -758,9 +826,23 @@ class BlockManager( } /** + * Remove all blocks belonging to the given RDD. + * @return The number of blocks removed. + */ + def removeRdd(rddId: Int): Int = { + // TODO: Instead of doing a linear scan on the blockInfo map, create another map that maps + // from RDD.id to blocks. + logInfo("Removing RDD " + rddId) + val rddPrefix = "rdd_" + rddId + "_" + val blocksToRemove = blockInfo.filter(_._1.startsWith(rddPrefix)).map(_._1) + blocksToRemove.foreach(blockId => removeBlock(blockId, false)) + blocksToRemove.size + } + + /** * Remove a block from both memory and disk. 
*/ - def removeBlock(blockId: String) { + def removeBlock(blockId: String, tellMaster: Boolean = true) { logInfo("Removing block " + blockId) val info = blockInfo.get(blockId).orNull if (info != null) info.synchronized { @@ -772,7 +854,7 @@ class BlockManager( "the disk or memory store") } blockInfo.remove(blockId) - if (info.tellMaster) { + if (tellMaster && info.tellMaster) { reportBlockStatus(blockId, info) } } else { @@ -805,7 +887,7 @@ class BlockManager( } def shouldCompress(blockId: String): Boolean = { - if (blockId.startsWith("shuffle_")) { + if (ShuffleBlockManager.isShuffle(blockId)) { compressShuffle } else if (blockId.startsWith("broadcast_")) { compressBroadcast @@ -820,7 +902,11 @@ class BlockManager( * Wrap an output stream for compression if block compression is enabled for its block type */ def wrapForCompression(blockId: String, s: OutputStream): OutputStream = { - if (shouldCompress(blockId)) new LZFOutputStream(s) else s + if (shouldCompress(blockId)) { + (new LZFOutputStream(s)).setFinishBlockOnFlush(true) + } else { + s + } } /** @@ -830,7 +916,10 @@ class BlockManager( if (shouldCompress(blockId)) new LZFInputStream(s) else s } - def dataSerialize(blockId: String, values: Iterator[Any]): ByteBuffer = { + def dataSerialize( + blockId: String, + values: Iterator[Any], + serializer: Serializer = defaultSerializer): ByteBuffer = { val byteStream = new FastByteArrayOutputStream(4096) val ser = serializer.newInstance() ser.serializeStream(wrapForCompression(blockId, byteStream)).writeAll(values).close() @@ -842,7 +931,10 @@ class BlockManager( * Deserializes a ByteBuffer into an iterator of values and disposes of it when the end of * the iterator is reached. 
*/ - def dataDeserialize(blockId: String, bytes: ByteBuffer): Iterator[Any] = { + def dataDeserialize( + blockId: String, + bytes: ByteBuffer, + serializer: Serializer = defaultSerializer): Iterator[Any] = { bytes.rewind() val stream = wrapForCompression(blockId, new ByteBufferInputStream(bytes, true)) serializer.newInstance().deserializeStream(stream).asIterator @@ -862,8 +954,8 @@ class BlockManager( } } -private[spark] -object BlockManager extends Logging { + +private[spark] object BlockManager extends Logging { val ID_GENERATOR = new IdGenerator @@ -873,7 +965,8 @@ object BlockManager extends Logging { } def getHeartBeatFrequencyFromSystemProperties: Long = - System.getProperty("spark.storage.blockManagerHeartBeatMs", "10000").toLong + + System.getProperty("spark.storage.blockManagerTimeoutIntervalMs", "60000").toLong / 4 def getDisableHeartBeatsForTesting: Boolean = System.getProperty("spark.test.disableBlockManagerHeartBeat", "false").toBoolean @@ -892,177 +985,43 @@ object BlockManager extends Logging { } } } -} - -class BlockFetcherIterator( - private val blockManager: BlockManager, - val blocksByAddress: Seq[(BlockManagerId, Seq[(String, Long)])] -) extends Iterator[(String, Option[Iterator[Any]])] with Logging with BlockFetchTracker { - import blockManager._ - - private var _remoteBytesRead = 0l - private var _remoteFetchTime = 0l - private var _fetchWaitTime = 0l - - if (blocksByAddress == null) { - throw new IllegalArgumentException("BlocksByAddress is null") - } - val totalBlocks = blocksByAddress.map(_._2.size).sum - logDebug("Getting " + totalBlocks + " blocks") - var startTime = System.currentTimeMillis - val localBlockIds = new ArrayBuffer[String]() - val remoteBlockIds = new HashSet[String]() - - // A result of a fetch. Includes the block ID, size in bytes, and a function to deserialize - // the block (since we want all deserializaton to happen in the calling thread); can also - // represent a fetch failure if size == -1. 
- class FetchResult(val blockId: String, val size: Long, val deserialize: () => Iterator[Any]) { - def failed: Boolean = size == -1 - } - - // A queue to hold our results. - val results = new LinkedBlockingQueue[FetchResult] - - // A request to fetch one or more blocks, complete with their sizes - class FetchRequest(val address: BlockManagerId, val blocks: Seq[(String, Long)]) { - val size = blocks.map(_._2).sum - } + def blockIdsToExecutorLocations(blockIds: Array[String], env: SparkEnv, blockManagerMaster: BlockManagerMaster = null): HashMap[String, List[String]] = { + // env == null and blockManagerMaster != null is used in tests + assert (env != null || blockManagerMaster != null) + val locationBlockIds: Seq[Seq[BlockManagerId]] = + if (env != null) { + env.blockManager.getLocationBlockIds(blockIds) + } else { + blockManagerMaster.getLocations(blockIds) + } - // Queue of fetch requests to issue; we'll pull requests off this gradually to make sure that - // the number of bytes in flight is limited to maxBytesInFlight - val fetchRequests = new Queue[FetchRequest] + // Convert from block master locations to executor locations (we need that for task scheduling) + val executorLocations = new HashMap[String, List[String]]() + for (i <- 0 until blockIds.length) { + val blockId = blockIds(i) + val blockLocations = locationBlockIds(i) - // Current bytes in flight from our requests - var bytesInFlight = 0L + val executors = new HashSet[String]() - def sendRequest(req: FetchRequest) { - logDebug("Sending request for %d blocks (%s) from %s".format( - req.blocks.size, Utils.memoryBytesToString(req.size), req.address.ip)) - val cmId = new ConnectionManagerId(req.address.ip, req.address.port) - val blockMessageArray = new BlockMessageArray(req.blocks.map { - case (blockId, size) => BlockMessage.fromGetBlock(GetBlock(blockId)) - }) - bytesInFlight += req.size - val sizeMap = req.blocks.toMap // so we can look up the size of each blockID - val fetchStart = 
System.currentTimeMillis() - val future = connectionManager.sendMessageReliably(cmId, blockMessageArray.toBufferMessage) - future.onSuccess { - case Some(message) => { - val fetchDone = System.currentTimeMillis() - _remoteFetchTime += fetchDone - fetchStart - val bufferMessage = message.asInstanceOf[BufferMessage] - val blockMessageArray = BlockMessageArray.fromBufferMessage(bufferMessage) - for (blockMessage <- blockMessageArray) { - if (blockMessage.getType != BlockMessage.TYPE_GOT_BLOCK) { - throw new SparkException( - "Unexpected message " + blockMessage.getType + " received from " + cmId) - } - val blockId = blockMessage.getId - results.put(new FetchResult( - blockId, sizeMap(blockId), () => dataDeserialize(blockId, blockMessage.getData))) - _remoteBytesRead += req.size - logDebug("Got remote block " + blockId + " after " + Utils.getUsedTimeMs(startTime)) + if (env != null) { + for (bkLocation <- blockLocations) { + val executorHostPort = env.resolveExecutorIdToHostPort(bkLocation.executorId, bkLocation.host) + executors += executorHostPort + // logInfo("bkLocation = " + bkLocation + ", executorHostPort = " + executorHostPort) } - } - case None => { - logError("Could not get block(s) from " + cmId) - for ((blockId, size) <- req.blocks) { - results.put(new FetchResult(blockId, -1, null)) + } else { + // Typically while testing, etc - revert to simply using host. + for (bkLocation <- blockLocations) { + executors += bkLocation.host + // logInfo("bkLocation = " + bkLocation + ", executorHostPort = " + executorHostPort) } } - } - } - // Split local and remote blocks. Remote blocks are further split into FetchRequests of size - // at most maxBytesInFlight in order to limit the amount of data in flight. 
- val remoteRequests = new ArrayBuffer[FetchRequest] - for ((address, blockInfos) <- blocksByAddress) { - if (address == blockManagerId) { - localBlockIds ++= blockInfos.map(_._1) - } else { - remoteBlockIds ++= blockInfos.map(_._1) - // Make our requests at least maxBytesInFlight / 5 in length; the reason to keep them - // smaller than maxBytesInFlight is to allow multiple, parallel fetches from up to 5 - // nodes, rather than blocking on reading output from one node. - val minRequestSize = math.max(maxBytesInFlight / 5, 1L) - logInfo("maxBytesInFlight: " + maxBytesInFlight + ", minRequest: " + minRequestSize) - val iterator = blockInfos.iterator - var curRequestSize = 0L - var curBlocks = new ArrayBuffer[(String, Long)] - while (iterator.hasNext) { - val (blockId, size) = iterator.next() - curBlocks += ((blockId, size)) - curRequestSize += size - if (curRequestSize >= minRequestSize) { - // Add this FetchRequest - remoteRequests += new FetchRequest(address, curBlocks) - curRequestSize = 0 - curBlocks = new ArrayBuffer[(String, Long)] - } - } - // Add in the final request - if (!curBlocks.isEmpty) { - remoteRequests += new FetchRequest(address, curBlocks) - } + executorLocations.put(blockId, executors.toSeq.toList) } - } - // Add the remote requests into our queue in a random order - fetchRequests ++= Utils.randomize(remoteRequests) - // Send out initial requests for blocks, up to our maxBytesInFlight - while (!fetchRequests.isEmpty && - (bytesInFlight == 0 || bytesInFlight + fetchRequests.front.size <= maxBytesInFlight)) { - sendRequest(fetchRequests.dequeue()) + executorLocations } - val numGets = remoteBlockIds.size - fetchRequests.size - logInfo("Started " + numGets + " remote gets in " + Utils.getUsedTimeMs(startTime)) - - // Get the local blocks while remote blocks are being fetched. 
Note that it's okay to do - // these all at once because they will just memory-map some files, so they won't consume - // any memory that might exceed our maxBytesInFlight - startTime = System.currentTimeMillis - for (id <- localBlockIds) { - getLocal(id) match { - case Some(iter) => { - results.put(new FetchResult(id, 0, () => iter)) // Pass 0 as size since it's not in flight - logDebug("Got local block " + id) - } - case None => { - throw new BlockException(id, "Could not get block " + id + " from local machine") - } - } - } - logDebug("Got local blocks in " + Utils.getUsedTimeMs(startTime) + " ms") - - //an iterator that will read fetched blocks off the queue as they arrive. - var resultsGotten = 0 - - def hasNext: Boolean = resultsGotten < totalBlocks - - def next(): (String, Option[Iterator[Any]]) = { - resultsGotten += 1 - val startFetchWait = System.currentTimeMillis() - val result = results.take() - val stopFetchWait = System.currentTimeMillis() - _fetchWaitTime += (stopFetchWait - startFetchWait) - bytesInFlight -= result.size - while (!fetchRequests.isEmpty && - (bytesInFlight == 0 || bytesInFlight + fetchRequests.front.size <= maxBytesInFlight)) { - sendRequest(fetchRequests.dequeue()) - } - (result.blockId, if (result.failed) None else Some(result.deserialize())) - } - - - //methods to profile the block fetching - def numLocalBlocks = localBlockIds.size - def numRemoteBlocks = remoteBlockIds.size - - def remoteFetchTime = _remoteFetchTime - def fetchWaitTime = _fetchWaitTime - - def remoteBytesRead = _remoteBytesRead - } diff --git a/core/src/main/scala/spark/storage/BlockManagerId.scala b/core/src/main/scala/spark/storage/BlockManagerId.scala index f2f1e77d41..1e557d6148 100644 --- a/core/src/main/scala/spark/storage/BlockManagerId.scala +++ b/core/src/main/scala/spark/storage/BlockManagerId.scala @@ -2,51 +2,70 @@ package spark.storage import java.io.{Externalizable, IOException, ObjectInput, ObjectOutput} import java.util.concurrent.ConcurrentHashMap 
+import spark.Utils /** * This class represent an unique identifier for a BlockManager. * The first 2 constructors of this class is made private to ensure that - * BlockManagerId objects can be created only using the factory method in - * [[spark.storage.BlockManager$]]. This allows de-duplication of ID objects. + * BlockManagerId objects can be created only using the apply method in + * the companion object. This allows de-duplication of ID objects. * Also, constructor parameters are private to ensure that parameters cannot * be modified from outside this class. */ private[spark] class BlockManagerId private ( private var executorId_ : String, - private var ip_ : String, - private var port_ : Int + private var host_ : String, + private var port_ : Int, + private var nettyPort_ : Int ) extends Externalizable { - private def this() = this(null, null, 0) // For deserialization only + private def this() = this(null, null, 0, 0) // For deserialization only def executorId: String = executorId_ - def ip: String = ip_ + if (null != host_){ + Utils.checkHost(host_, "Expected hostname") + assert (port_ > 0) + } + + def hostPort: String = { + // DEBUG code + Utils.checkHost(host) + assert (port > 0) + + host + ":" + port + } + + def host: String = host_ def port: Int = port_ + def nettyPort: Int = nettyPort_ + override def writeExternal(out: ObjectOutput) { out.writeUTF(executorId_) - out.writeUTF(ip_) + out.writeUTF(host_) out.writeInt(port_) + out.writeInt(nettyPort_) } override def readExternal(in: ObjectInput) { executorId_ = in.readUTF() - ip_ = in.readUTF() + host_ = in.readUTF() port_ = in.readInt() + nettyPort_ = in.readInt() } @throws(classOf[IOException]) private def readResolve(): Object = BlockManagerId.getCachedBlockManagerId(this) - override def toString = "BlockManagerId(%s, %s, %d)".format(executorId, ip, port) + override def toString = "BlockManagerId(%s, %s, %d, %d)".format(executorId, host, port, nettyPort) - override def hashCode: Int = 
(executorId.hashCode * 41 + ip.hashCode) * 41 + port + override def hashCode: Int = (executorId.hashCode * 41 + host.hashCode) * 41 + port + nettyPort override def equals(that: Any) = that match { case id: BlockManagerId => - executorId == id.executorId && port == id.port && ip == id.ip + executorId == id.executorId && port == id.port && host == id.host && nettyPort == id.nettyPort case _ => false } @@ -55,8 +74,17 @@ private[spark] class BlockManagerId private ( private[spark] object BlockManagerId { - def apply(execId: String, ip: String, port: Int) = - getCachedBlockManagerId(new BlockManagerId(execId, ip, port)) + /** + * Returns a [[spark.storage.BlockManagerId]] for the given configuraiton. + * + * @param execId ID of the executor. + * @param host Host name of the block manager. + * @param port Port of the block manager. + * @param nettyPort Optional port for the Netty-based shuffle sender. + * @return A new [[spark.storage.BlockManagerId]]. + */ + def apply(execId: String, host: String, port: Int, nettyPort: Int) = + getCachedBlockManagerId(new BlockManagerId(execId, host, port, nettyPort)) def apply(in: ObjectInput) = { val obj = new BlockManagerId() @@ -67,11 +95,7 @@ private[spark] object BlockManagerId { val blockManagerIdCache = new ConcurrentHashMap[BlockManagerId, BlockManagerId]() def getCachedBlockManagerId(id: BlockManagerId): BlockManagerId = { - if (blockManagerIdCache.containsKey(id)) { - blockManagerIdCache.get(id) - } else { - blockManagerIdCache.put(id, id) - id - } + blockManagerIdCache.putIfAbsent(id, id) + blockManagerIdCache.get(id) } } diff --git a/core/src/main/scala/spark/storage/BlockManagerMaster.scala b/core/src/main/scala/spark/storage/BlockManagerMaster.scala index 4e55936d28..6a9278292e 100644 --- a/core/src/main/scala/spark/storage/BlockManagerMaster.scala +++ b/core/src/main/scala/spark/storage/BlockManagerMaster.scala @@ -9,10 +9,13 @@ import scala.util.Random import akka.actor.{Actor, ActorRef, ActorSystem, Props} import 
scala.concurrent.Await +import scala.concurrent.Future +import scala.concurrent.ExecutionContext.Implicits.global + import akka.pattern.ask import scala.concurrent.duration._ -import spark.{Logging, SparkException, Utils} +import spark.{Logging, SparkException} private[spark] class BlockManagerMaster(var driverActor: ActorRef) extends Logging { @@ -21,7 +24,7 @@ private[spark] class BlockManagerMaster(var driverActor: ActorRef) extends Loggi val DRIVER_AKKA_ACTOR_NAME = "BlockManagerMaster" - val timeout = 10.seconds + val timeout = Duration.create(System.getProperty("spark.akka.askTimeout", "10").toLong, "seconds") /** Remove a dead executor from the driver actor. This is only called on the driver side. */ def removeExecutor(execId: String) { @@ -87,6 +90,19 @@ private[spark] class BlockManagerMaster(var driverActor: ActorRef) extends Loggi } /** + * Remove all blocks belonging to the given RDD. + */ + def removeRdd(rddId: Int, blocking: Boolean) { + val future = askDriverWithReply[Future[Seq[Int]]](RemoveRdd(rddId)) + future onFailure { + case e: Throwable => logError("Failed to remove RDD " + rddId, e) + } + if (blocking) { + Await.result(future, timeout) + } + } + + /** * Return the memory status for each block manager, in the form of a map from * the block manager's id to two long values. 
The first value is the maximum * amount of memory allocated for the block manager, while the second is the @@ -97,7 +113,7 @@ private[spark] class BlockManagerMaster(var driverActor: ActorRef) extends Loggi } def getStorageStatus: Array[StorageStatus] = { - askDriverWithReply[ArrayBuffer[StorageStatus]](GetStorageStatus).toArray + askDriverWithReply[Array[StorageStatus]](GetStorageStatus) } /** Stop the driver actor, called only on the Spark driver node */ @@ -134,7 +150,7 @@ private[spark] class BlockManagerMaster(var driverActor: ActorRef) extends Loggi val future = driverActor.ask(message)(timeout) val result = Await.result(future, timeout) if (result == null) { - throw new Exception("BlockManagerMaster returned null") + throw new SparkException("BlockManagerMaster returned null") } return result.asInstanceOf[T] } catch { diff --git a/core/src/main/scala/spark/storage/BlockManagerMasterActor.scala b/core/src/main/scala/spark/storage/BlockManagerMasterActor.scala index 2d39e2c15c..6b5e38124b 100644 --- a/core/src/main/scala/spark/storage/BlockManagerMasterActor.scala +++ b/core/src/main/scala/spark/storage/BlockManagerMasterActor.scala @@ -2,14 +2,16 @@ package spark.storage import java.util.{HashMap => JHashMap} -import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet} +import scala.collection.mutable import scala.collection.JavaConversions._ -import scala.util.Random import akka.actor.{Actor, ActorRef, Cancellable} +import akka.pattern.ask + import scala.concurrent.duration._ +import scala.concurrent.Future -import spark.{Logging, Utils} +import spark.{Logging, Utils, SparkException} /** * BlockManagerMasterActor is an actor on the master node to track statuses of @@ -20,13 +22,16 @@ class BlockManagerMasterActor(val isLocal: Boolean) extends Actor with Logging { // Mapping from block manager id to the block manager's information. 
private val blockManagerInfo = - new HashMap[BlockManagerId, BlockManagerMasterActor.BlockManagerInfo] + new mutable.HashMap[BlockManagerId, BlockManagerMasterActor.BlockManagerInfo] // Mapping from executor ID to block manager ID. - private val blockManagerIdByExecutor = new HashMap[String, BlockManagerId] + private val blockManagerIdByExecutor = new mutable.HashMap[String, BlockManagerId] // Mapping from block id to the set of block managers that have the block. - private val blockLocations = new JHashMap[String, Pair[Int, HashSet[BlockManagerId]]] + private val blockLocations = new JHashMap[String, mutable.HashSet[BlockManagerId]] + + val akkaTimeout = Duration.create( + System.getProperty("spark.akka.askTimeout", "10").toLong, "seconds") initLogging() @@ -34,7 +39,7 @@ class BlockManagerMasterActor(val isLocal: Boolean) extends Actor with Logging { "" + (BlockManager.getHeartBeatFrequencyFromSystemProperties * 3)).toLong val checkTimeoutInterval = System.getProperty("spark.storage.blockManagerTimeoutIntervalMs", - "5000").toLong + "60000").toLong var timeoutCheckingTask: Cancellable = null @@ -50,28 +55,34 @@ class BlockManagerMasterActor(val isLocal: Boolean) extends Actor with Logging { def receive = { case RegisterBlockManager(blockManagerId, maxMemSize, slaveActor) => register(blockManagerId, maxMemSize, slaveActor) + sender ! true case UpdateBlockInfo(blockManagerId, blockId, storageLevel, deserializedSize, size) => + // TODO: Ideally we want to handle all the message replies in receive instead of in the + // individual private methods. updateBlockInfo(blockManagerId, blockId, storageLevel, deserializedSize, size) case GetLocations(blockId) => - getLocations(blockId) + sender ! getLocations(blockId) case GetLocationsMultipleBlockIds(blockIds) => - getLocationsMultipleBlockIds(blockIds) + sender ! 
getLocationsMultipleBlockIds(blockIds) case GetPeers(blockManagerId, size) => - getPeersDeterministic(blockManagerId, size) - /*getPeers(blockManagerId, size)*/ + sender ! getPeers(blockManagerId, size) case GetMemoryStatus => - getMemoryStatus + sender ! memoryStatus case GetStorageStatus => - getStorageStatus + sender ! storageStatus + + case RemoveRdd(rddId) => + sender ! removeRdd(rddId) case RemoveBlock(blockId) => - removeBlock(blockId) + removeBlockFromWorkers(blockId) + sender ! true case RemoveExecutor(execId) => removeExecutor(execId) @@ -81,7 +92,7 @@ class BlockManagerMasterActor(val isLocal: Boolean) extends Actor with Logging { logInfo("Stopping BlockManagerMaster") sender ! true if (timeoutCheckingTask != null) { - timeoutCheckingTask.cancel + timeoutCheckingTask.cancel() } context.stop(self) @@ -89,13 +100,36 @@ class BlockManagerMasterActor(val isLocal: Boolean) extends Actor with Logging { expireDeadHosts() case HeartBeat(blockManagerId) => - heartBeat(blockManagerId) + sender ! heartBeat(blockManagerId) case other => - logInfo("Got unknown message: " + other) + logWarning("Got unknown message: " + other) + } + + private def removeRdd(rddId: Int): Future[Seq[Int]] = { + // First remove the metadata for the given RDD, and then asynchronously remove the blocks + // from the slaves. + + val prefix = "rdd_" + rddId + "_" + // Find all blocks for the given RDD, remove the block from both blockLocations and + // the blockManagerInfo that is tracking the blocks. + val blocks = blockLocations.keySet().filter(_.startsWith(prefix)) + blocks.foreach { blockId => + val bms: mutable.HashSet[BlockManagerId] = blockLocations.get(blockId) + bms.foreach(bm => blockManagerInfo.get(bm).foreach(_.removeBlock(blockId))) + blockLocations.remove(blockId) + } + + // Ask the slaves to remove the RDD, and put the result in a sequence of Futures. + // The dispatcher is used as an implicit argument into the Future sequence construction. 
+ import context.dispatcher + val removeMsg = RemoveRdd(rddId) + Future.sequence(blockManagerInfo.values.map { bm => + bm.slaveActor.ask(removeMsg)(akkaTimeout).mapTo[Int] + }.toSeq) } - def removeBlockManager(blockManagerId: BlockManagerId) { + private def removeBlockManager(blockManagerId: BlockManagerId) { val info = blockManagerInfo(blockManagerId) // Remove the block manager from blockManagerIdByExecutor. @@ -106,7 +140,7 @@ class BlockManagerMasterActor(val isLocal: Boolean) extends Actor with Logging { val iterator = info.blocks.keySet.iterator while (iterator.hasNext) { val blockId = iterator.next - val locations = blockLocations.get(blockId)._2 + val locations = blockLocations.get(blockId) locations -= blockManagerId if (locations.size == 0) { blockLocations.remove(locations) @@ -114,11 +148,11 @@ class BlockManagerMasterActor(val isLocal: Boolean) extends Actor with Logging { } } - def expireDeadHosts() { + private def expireDeadHosts() { logTrace("Checking for hosts with no recent heart beats in BlockManagerMaster.") val now = System.currentTimeMillis() val minSeenTime = now - slaveTimeout - val toRemove = new HashSet[BlockManagerId] + val toRemove = new mutable.HashSet[BlockManagerId] for (info <- blockManagerInfo.values) { if (info.lastSeenMs < minSeenTime) { logWarning("Removing BlockManager " + info.blockManagerId + " with no recent heart beats: " + @@ -129,31 +163,26 @@ class BlockManagerMasterActor(val isLocal: Boolean) extends Actor with Logging { toRemove.foreach(removeBlockManager) } - def removeExecutor(execId: String) { + private def removeExecutor(execId: String) { logInfo("Trying to remove executor " + execId + " from BlockManagerMaster.") blockManagerIdByExecutor.get(execId).foreach(removeBlockManager) - sender ! 
true } - def heartBeat(blockManagerId: BlockManagerId) { + private def heartBeat(blockManagerId: BlockManagerId): Boolean = { if (!blockManagerInfo.contains(blockManagerId)) { - if (blockManagerId.executorId == "<driver>" && !isLocal) { - sender ! true - } else { - sender ! false - } + blockManagerId.executorId == "<driver>" && !isLocal } else { blockManagerInfo(blockManagerId).updateLastSeenMs() - sender ! true + true } } // Remove a block from the slaves that have it. This can only be used to remove // blocks that the master knows about. - private def removeBlock(blockId: String) { - val block = blockLocations.get(blockId) - if (block != null) { - block._2.foreach { blockManagerId: BlockManagerId => + private def removeBlockFromWorkers(blockId: String) { + val locations = blockLocations.get(blockId) + if (locations != null) { + locations.foreach { blockManagerId: BlockManagerId => val blockManager = blockManagerInfo.get(blockManagerId) if (blockManager.isDefined) { // Remove the block from the slave's BlockManager. @@ -163,23 +192,20 @@ class BlockManagerMasterActor(val isLocal: Boolean) extends Actor with Logging { } } } - sender ! true } // Return a map from the block manager id to max memory and remaining memory. - private def getMemoryStatus() { - val res = blockManagerInfo.map { case(blockManagerId, info) => + private def memoryStatus: Map[BlockManagerId, (Long, Long)] = { + blockManagerInfo.map { case(blockManagerId, info) => (blockManagerId, (info.maxMem, info.remainingMem)) }.toMap - sender ! res } - private def getStorageStatus() { - val res = blockManagerInfo.map { case(blockManagerId, info) => + private def storageStatus: Array[StorageStatus] = { + blockManagerInfo.map { case(blockManagerId, info) => import collection.JavaConverters._ StorageStatus(blockManagerId, info.maxMem, info.blocks.asScala.toMap) - } - sender ! 
res + }.toArray } private def register(id: BlockManagerId, maxMemSize: Long, slaveActor: ActorRef) { @@ -188,7 +214,8 @@ class BlockManagerMasterActor(val isLocal: Boolean) extends Actor with Logging { } else if (!blockManagerInfo.contains(id)) { blockManagerIdByExecutor.get(id.executorId) match { case Some(manager) => - // A block manager of the same host name already exists + // A block manager of the same executor already exists. + // This should never happen. Let's just quit. logError("Got two different block manager registrations on " + id.executorId) System.exit(1) case None => @@ -197,7 +224,6 @@ class BlockManagerMasterActor(val isLocal: Boolean) extends Actor with Logging { blockManagerInfo(id) = new BlockManagerMasterActor.BlockManagerInfo( id, System.currentTimeMillis(), maxMemSize, slaveActor) } - sender ! true } private def updateBlockInfo( @@ -226,12 +252,12 @@ class BlockManagerMasterActor(val isLocal: Boolean) extends Actor with Logging { blockManagerInfo(blockManagerId).updateBlockInfo(blockId, storageLevel, memSize, diskSize) - var locations: HashSet[BlockManagerId] = null + var locations: mutable.HashSet[BlockManagerId] = null if (blockLocations.containsKey(blockId)) { - locations = blockLocations.get(blockId)._2 + locations = blockLocations.get(blockId) } else { - locations = new HashSet[BlockManagerId] - blockLocations.put(blockId, (storageLevel.replication, locations)) + locations = new mutable.HashSet[BlockManagerId] + blockLocations.put(blockId, locations) } if (storageLevel.isValid) { @@ -247,70 +273,24 @@ class BlockManagerMasterActor(val isLocal: Boolean) extends Actor with Logging { sender ! true } - private def getLocations(blockId: String) { - val startTimeMs = System.currentTimeMillis() - val tmp = " " + blockId + " " - if (blockLocations.containsKey(blockId)) { - var res: ArrayBuffer[BlockManagerId] = new ArrayBuffer[BlockManagerId] - res.appendAll(blockLocations.get(blockId)._2) - sender ! 
res.toSeq - } else { - var res: ArrayBuffer[BlockManagerId] = new ArrayBuffer[BlockManagerId] - sender ! res - } + private def getLocations(blockId: String): Seq[BlockManagerId] = { + if (blockLocations.containsKey(blockId)) blockLocations.get(blockId).toSeq else Seq.empty } - private def getLocationsMultipleBlockIds(blockIds: Array[String]) { - def getLocations(blockId: String): Seq[BlockManagerId] = { - val tmp = blockId - if (blockLocations.containsKey(blockId)) { - var res: ArrayBuffer[BlockManagerId] = new ArrayBuffer[BlockManagerId] - res.appendAll(blockLocations.get(blockId)._2) - return res.toSeq - } else { - var res: ArrayBuffer[BlockManagerId] = new ArrayBuffer[BlockManagerId] - return res.toSeq - } - } - - var res: ArrayBuffer[Seq[BlockManagerId]] = new ArrayBuffer[Seq[BlockManagerId]] - for (blockId <- blockIds) { - res.append(getLocations(blockId)) - } - sender ! res.toSeq + private def getLocationsMultipleBlockIds(blockIds: Array[String]): Seq[Seq[BlockManagerId]] = { + blockIds.map(blockId => getLocations(blockId)) } - private def getPeers(blockManagerId: BlockManagerId, size: Int) { - var peers: Array[BlockManagerId] = blockManagerInfo.keySet.toArray - var res: ArrayBuffer[BlockManagerId] = new ArrayBuffer[BlockManagerId] - res.appendAll(peers) - res -= blockManagerId - val rand = new Random(System.currentTimeMillis()) - while (res.length > size) { - res.remove(rand.nextInt(res.length)) - } - sender ! 
res.toSeq - } - - private def getPeersDeterministic(blockManagerId: BlockManagerId, size: Int) { - var peers: Array[BlockManagerId] = blockManagerInfo.keySet.toArray - var res: ArrayBuffer[BlockManagerId] = new ArrayBuffer[BlockManagerId] + private def getPeers(blockManagerId: BlockManagerId, size: Int): Seq[BlockManagerId] = { + val peers: Array[BlockManagerId] = blockManagerInfo.keySet.toArray val selfIndex = peers.indexOf(blockManagerId) if (selfIndex == -1) { - throw new Exception("Self index for " + blockManagerId + " not found") + throw new SparkException("Self index for " + blockManagerId + " not found") } // Note that this logic will select the same node multiple times if there aren't enough peers - var index = selfIndex - while (res.size < size) { - index += 1 - if (index == selfIndex) { - throw new Exception("More peer expected than available") - } - res += peers(index % peers.size) - } - sender ! res.toSeq + Array.tabulate[BlockManagerId](size) { i => peers((selfIndex + i + 1) % peers.length) }.toSeq } } @@ -333,8 +313,8 @@ object BlockManagerMasterActor { // Mapping from block id to its status. 
private val _blocks = new JHashMap[String, BlockStatus] - logInfo("Registering block manager %s:%d with %s RAM".format( - blockManagerId.ip, blockManagerId.port, Utils.memoryBytesToString(maxMem))) + logInfo("Registering block manager %s with %s RAM".format( + blockManagerId.hostPort, Utils.memoryBytesToString(maxMem))) def updateLastSeenMs() { _lastSeenMs = System.currentTimeMillis() @@ -359,13 +339,13 @@ object BlockManagerMasterActor { _blocks.put(blockId, BlockStatus(storageLevel, memSize, diskSize)) if (storageLevel.useMemory) { _remainingMem -= memSize - logInfo("Added %s in memory on %s:%d (size: %s, free: %s)".format( - blockId, blockManagerId.ip, blockManagerId.port, Utils.memoryBytesToString(memSize), + logInfo("Added %s in memory on %s (size: %s, free: %s)".format( + blockId, blockManagerId.hostPort, Utils.memoryBytesToString(memSize), Utils.memoryBytesToString(_remainingMem))) } if (storageLevel.useDisk) { - logInfo("Added %s on disk on %s:%d (size: %s)".format( - blockId, blockManagerId.ip, blockManagerId.port, Utils.memoryBytesToString(diskSize))) + logInfo("Added %s on disk on %s (size: %s)".format( + blockId, blockManagerId.hostPort, Utils.memoryBytesToString(diskSize))) } } else if (_blocks.containsKey(blockId)) { // If isValid is not true, drop the block. 
@@ -373,17 +353,24 @@ object BlockManagerMasterActor { _blocks.remove(blockId) if (blockStatus.storageLevel.useMemory) { _remainingMem += blockStatus.memSize - logInfo("Removed %s on %s:%d in memory (size: %s, free: %s)".format( - blockId, blockManagerId.ip, blockManagerId.port, Utils.memoryBytesToString(memSize), + logInfo("Removed %s on %s in memory (size: %s, free: %s)".format( + blockId, blockManagerId.hostPort, Utils.memoryBytesToString(memSize), Utils.memoryBytesToString(_remainingMem))) } if (blockStatus.storageLevel.useDisk) { - logInfo("Removed %s on %s:%d on disk (size: %s)".format( - blockId, blockManagerId.ip, blockManagerId.port, Utils.memoryBytesToString(diskSize))) + logInfo("Removed %s on %s on disk (size: %s)".format( + blockId, blockManagerId.hostPort, Utils.memoryBytesToString(diskSize))) } } } + def removeBlock(blockId: String) { + if (_blocks.containsKey(blockId)) { + _remainingMem += _blocks.get(blockId).memSize + _blocks.remove(blockId) + } + } + def remainingMem: Long = _remainingMem def lastSeenMs: Long = _lastSeenMs diff --git a/core/src/main/scala/spark/storage/BlockManagerMessages.scala b/core/src/main/scala/spark/storage/BlockManagerMessages.scala index cff48d9909..0010726c8d 100644 --- a/core/src/main/scala/spark/storage/BlockManagerMessages.scala +++ b/core/src/main/scala/spark/storage/BlockManagerMessages.scala @@ -16,6 +16,9 @@ sealed trait ToBlockManagerSlave private[spark] case class RemoveBlock(blockId: String) extends ToBlockManagerSlave +// Remove all blocks belonging to a specific RDD. +private[spark] case class RemoveRdd(rddId: Int) extends ToBlockManagerSlave + ////////////////////////////////////////////////////////////////////////////////// // Messages from slaves to the master. 
diff --git a/core/src/main/scala/spark/storage/BlockManagerSlaveActor.scala b/core/src/main/scala/spark/storage/BlockManagerSlaveActor.scala index f570cdc52d..b264d1deb5 100644 --- a/core/src/main/scala/spark/storage/BlockManagerSlaveActor.scala +++ b/core/src/main/scala/spark/storage/BlockManagerSlaveActor.scala @@ -11,6 +11,12 @@ import spark.{Logging, SparkException, Utils} */ class BlockManagerSlaveActor(blockManager: BlockManager) extends Actor { override def receive = { - case RemoveBlock(blockId) => blockManager.removeBlock(blockId) + + case RemoveBlock(blockId) => + blockManager.removeBlock(blockId) + + case RemoveRdd(rddId) => + val numBlocksRemoved = blockManager.removeRdd(rddId) + sender ! numBlocksRemoved } } diff --git a/core/src/main/scala/spark/storage/BlockManagerUI.scala b/core/src/main/scala/spark/storage/BlockManagerUI.scala index a3397a0fb4..631455abcd 100644 --- a/core/src/main/scala/spark/storage/BlockManagerUI.scala +++ b/core/src/main/scala/spark/storage/BlockManagerUI.scala @@ -1,10 +1,12 @@ package spark.storage import akka.actor.{ActorRef, ActorSystem} + import akka.util.Timeout import scala.concurrent.duration._ import spray.httpx.TwirlSupport._ import spray.routing.Directives + import spark.{Logging, SparkContext} import spark.util.AkkaUtils import spark.Utils @@ -20,20 +22,21 @@ class BlockManagerUI(val actorSystem: ActorSystem, blockManagerMaster: ActorRef, implicit val implicitActorSystem = actorSystem val STATIC_RESOURCE_DIR = "spark/deploy/static" - implicit val timeout = Timeout(10 seconds) + implicit val timeout = Duration.create(System.getProperty("spark.akka.askTimeout", "10").toLong, "seconds") + val host = Utils.localHostName() + val port = if (System.getProperty("spark.ui.port") != null) { + System.getProperty("spark.ui.port").toInt + } else { + // TODO: Unfortunately, it's not possible to pass port 0 to spray and figure out which + // random port it bound to, so we have to try to find a local one by creating a socket. 
+ Utils.findFreePort() + } /** Start a HTTP server to run the Web interface */ def start() { try { - val port = if (System.getProperty("spark.ui.port") != null) { - System.getProperty("spark.ui.port").toInt - } else { - // TODO: Unfortunately, it's not possible to pass port 0 to spray and figure out which - // random port it bound to, so we have to try to find a local one by creating a socket. - Utils.findFreePort() - } - AkkaUtils.startSprayServer(actorSystem, "0.0.0.0", port, handler) - logInfo("Started BlockManager web UI at http://%s:%d".format(Utils.localHostName(), port)) + AkkaUtils.startSprayServer(actorSystem, "0.0.0.0", port, handler, "BlockManagerHTTPServer") + logInfo("Started BlockManager web UI at http://%s:%d".format(host, port)) } catch { case e: Exception => logError("Failed to create BlockManager WebUI", e) @@ -74,4 +77,6 @@ class BlockManagerUI(val actorSystem: ActorSystem, blockManagerMaster: ActorRef, } } } + + private[spark] def appUIAddress = "http://" + host + ":" + port } diff --git a/core/src/main/scala/spark/storage/BlockManagerWorker.scala b/core/src/main/scala/spark/storage/BlockManagerWorker.scala index d2985559c1..3057ade233 100644 --- a/core/src/main/scala/spark/storage/BlockManagerWorker.scala +++ b/core/src/main/scala/spark/storage/BlockManagerWorker.scala @@ -2,13 +2,7 @@ package spark.storage import java.nio.ByteBuffer -import scala.actors._ -import scala.actors.Actor._ -import scala.actors.remote._ -import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet} -import scala.util.Random - -import spark.{Logging, Utils, SparkEnv} +import spark.{Logging, Utils} import spark.network._ /** @@ -19,7 +13,7 @@ import spark.network._ */ private[spark] class BlockManagerWorker(val blockManager: BlockManager) extends Logging { initLogging() - + blockManager.connectionManager.onReceiveMessage(onBlockMessageReceive) def onBlockMessageReceive(msg: Message, id: ConnectionManagerId): Option[Message] = { @@ -51,7 +45,7 @@ private[spark] class 
BlockManagerWorker(val blockManager: BlockManager) extends logDebug("Received [" + pB + "]") putBlock(pB.id, pB.data, pB.level) return None - } + } case BlockMessage.TYPE_GET_BLOCK => { val gB = new GetBlock(blockMessage.getId) logDebug("Received [" + gB + "]") @@ -88,30 +82,26 @@ private[spark] class BlockManagerWorker(val blockManager: BlockManager) extends private[spark] object BlockManagerWorker extends Logging { private var blockManagerWorker: BlockManagerWorker = null - private val DATA_TRANSFER_TIME_OUT_MS: Long = 500 - private val REQUEST_RETRY_INTERVAL_MS: Long = 1000 - + initLogging() - + def startBlockManagerWorker(manager: BlockManager) { blockManagerWorker = new BlockManagerWorker(manager) } - + def syncPutBlock(msg: PutBlock, toConnManagerId: ConnectionManagerId): Boolean = { val blockManager = blockManagerWorker.blockManager - val connectionManager = blockManager.connectionManager - val serializer = blockManager.serializer + val connectionManager = blockManager.connectionManager val blockMessage = BlockMessage.fromPutBlock(msg) val blockMessageArray = new BlockMessageArray(blockMessage) val resultMessage = connectionManager.sendMessageReliablySync( toConnManagerId, blockMessageArray.toBufferMessage) return (resultMessage != None) } - + def syncGetBlock(msg: GetBlock, toConnManagerId: ConnectionManagerId): ByteBuffer = { val blockManager = blockManagerWorker.blockManager - val connectionManager = blockManager.connectionManager - val serializer = blockManager.serializer + val connectionManager = blockManager.connectionManager val blockMessage = BlockMessage.fromGetBlock(msg) val blockMessageArray = new BlockMessageArray(blockMessage) val responseMessage = connectionManager.sendMessageReliablySync( diff --git a/core/src/main/scala/spark/storage/BlockMessageArray.scala b/core/src/main/scala/spark/storage/BlockMessageArray.scala index a25decb123..ee0c5ff9a2 100644 --- a/core/src/main/scala/spark/storage/BlockMessageArray.scala +++ 
b/core/src/main/scala/spark/storage/BlockMessageArray.scala @@ -115,6 +115,7 @@ private[spark] object BlockMessageArray { val newBuffer = ByteBuffer.allocate(totalSize) newBuffer.clear() bufferMessage.buffers.foreach(buffer => { + assert (0 == buffer.position()) newBuffer.put(buffer) buffer.rewind() }) diff --git a/core/src/main/scala/spark/storage/BlockObjectWriter.scala b/core/src/main/scala/spark/storage/BlockObjectWriter.scala new file mode 100644 index 0000000000..42e2b07d5c --- /dev/null +++ b/core/src/main/scala/spark/storage/BlockObjectWriter.scala @@ -0,0 +1,50 @@ +package spark.storage + +import java.nio.ByteBuffer + + +/** + * An interface for writing JVM objects to some underlying storage. This interface allows + * appending data to an existing block, and can guarantee atomicity in the case of faults + * as it allows the caller to revert partial writes. + * + * This interface does not support concurrent writes. + */ +abstract class BlockObjectWriter(val blockId: String) { + + var closeEventHandler: () => Unit = _ + + def open(): BlockObjectWriter + + def close() { + closeEventHandler() + } + + def isOpen: Boolean + + def registerCloseEventHandler(handler: () => Unit) { + closeEventHandler = handler + } + + /** + * Flush the partial writes and commit them as a single atomic block. Return the + * number of bytes written for this commit. + */ + def commit(): Long + + /** + * Reverts writes that haven't been flushed yet. Callers should invoke this function + * when there are runtime exceptions. + */ + def revertPartialWrites() + + /** + * Writes an object. + */ + def write(value: Any) + + /** + * Size of the valid writes, in bytes. 
+ */ + def size(): Long +} diff --git a/core/src/main/scala/spark/storage/DelegateBlockFetchTracker.scala b/core/src/main/scala/spark/storage/DelegateBlockFetchTracker.scala deleted file mode 100644 index f6c28dce52..0000000000 --- a/core/src/main/scala/spark/storage/DelegateBlockFetchTracker.scala +++ /dev/null @@ -1,12 +0,0 @@ -package spark.storage - -private[spark] trait DelegateBlockFetchTracker extends BlockFetchTracker { - var delegate : BlockFetchTracker = _ - def setDelegate(d: BlockFetchTracker) {delegate = d} - def totalBlocks = delegate.totalBlocks - def numLocalBlocks = delegate.numLocalBlocks - def numRemoteBlocks = delegate.numRemoteBlocks - def remoteFetchTime = delegate.remoteFetchTime - def fetchWaitTime = delegate.fetchWaitTime - def remoteBytesRead = delegate.remoteBytesRead -} diff --git a/core/src/main/scala/spark/storage/DiskStore.scala b/core/src/main/scala/spark/storage/DiskStore.scala index ddbf8821ad..da859eebcb 100644 --- a/core/src/main/scala/spark/storage/DiskStore.scala +++ b/core/src/main/scala/spark/storage/DiskStore.scala @@ -1,41 +1,126 @@ package spark.storage +import java.io.{File, FileOutputStream, OutputStream, RandomAccessFile} import java.nio.ByteBuffer -import java.io.{File, FileOutputStream, RandomAccessFile} +import java.nio.channels.FileChannel import java.nio.channels.FileChannel.MapMode import java.util.{Random, Date} import java.text.SimpleDateFormat -import it.unimi.dsi.fastutil.io.FastBufferedOutputStream - import scala.collection.mutable.ArrayBuffer -import spark.executor.ExecutorExitCode +import it.unimi.dsi.fastutil.io.FastBufferedOutputStream import spark.Utils +import spark.executor.ExecutorExitCode +import spark.serializer.{Serializer, SerializationStream} +import spark.Logging +import spark.network.netty.ShuffleSender +import spark.network.netty.PathResolver + /** * Stores BlockManager blocks on disk. 
*/ private class DiskStore(blockManager: BlockManager, rootDirs: String) - extends BlockStore(blockManager) { + extends BlockStore(blockManager) with Logging { + + class DiskBlockObjectWriter(blockId: String, serializer: Serializer, bufferSize: Int) + extends BlockObjectWriter(blockId) { + + private val f: File = createFile(blockId /*, allowAppendExisting */) + + // The file channel, used for repositioning / truncating the file. + private var channel: FileChannel = null + private var bs: OutputStream = null + private var objOut: SerializationStream = null + private var lastValidPosition = 0L + private var initialized = false + + override def open(): DiskBlockObjectWriter = { + val fos = new FileOutputStream(f, true) + channel = fos.getChannel() + bs = blockManager.wrapForCompression(blockId, new FastBufferedOutputStream(fos, bufferSize)) + objOut = serializer.newInstance().serializeStream(bs) + initialized = true + this + } + + override def close() { + if (initialized) { + objOut.close() + bs.close() + channel = null + bs = null + objOut = null + } + // Invoke the close callback handler. + super.close() + } - val MAX_DIR_CREATION_ATTEMPTS: Int = 10 - val subDirsPerLocalDir = System.getProperty("spark.diskStore.subDirectories", "64").toInt + override def isOpen: Boolean = objOut != null + // Flush the partial writes, and set valid length to be the length of the entire file. + // Return the number of bytes written for this commit. + override def commit(): Long = { + if (initialized) { + // NOTE: Flush the serializer first and then the compressed/buffered output stream + objOut.flush() + bs.flush() + val prevPos = lastValidPosition + lastValidPosition = channel.position() + lastValidPosition - prevPos + } else { + // lastValidPosition is zero if stream is uninitialized + lastValidPosition + } + } + + override def revertPartialWrites() { + if (initialized) { + // Discard current writes. 
We do this by flushing the outstanding writes and + // truncate the file to the last valid position. + objOut.flush() + bs.flush() + channel.truncate(lastValidPosition) + } + } + + override def write(value: Any) { + if (!initialized) { + open() + } + objOut.writeObject(value) + } + + override def size(): Long = lastValidPosition + } + + private val MAX_DIR_CREATION_ATTEMPTS: Int = 10 + private val subDirsPerLocalDir = System.getProperty("spark.diskStore.subDirectories", "64").toInt + + private var shuffleSender : ShuffleSender = null // Create one local directory for each path mentioned in spark.local.dir; then, inside this // directory, create multiple subdirectories that we will hash files into, in order to avoid // having really large inodes at the top level. - val localDirs = createLocalDirs() - val subDirs = Array.fill(localDirs.length)(new Array[File](subDirsPerLocalDir)) + private val localDirs: Array[File] = createLocalDirs() + private val subDirs = Array.fill(localDirs.length)(new Array[File](subDirsPerLocalDir)) addShutdownHook() + def getBlockWriter(blockId: String, serializer: Serializer, bufferSize: Int) + : BlockObjectWriter = { + new DiskBlockObjectWriter(blockId, serializer, bufferSize) + } + override def getSize(blockId: String): Long = { getFile(blockId).length() } - override def putBytes(blockId: String, bytes: ByteBuffer, level: StorageLevel) { + override def putBytes(blockId: String, _bytes: ByteBuffer, level: StorageLevel) { + // So that we do not modify the input offsets ! 
+ // duplicate does not copy buffer, so inexpensive + val bytes = _bytes.duplicate() logDebug("Attempting to put block " + blockId) val startTime = System.currentTimeMillis val file = createFile(blockId) @@ -49,6 +134,18 @@ private class DiskStore(blockManager: BlockManager, rootDirs: String) blockId, Utils.memoryBytesToString(bytes.limit), (finishTime - startTime))) } + private def getFileBytes(file: File): ByteBuffer = { + val length = file.length() + val channel = new RandomAccessFile(file, "r").getChannel() + val buffer = try { + channel.map(MapMode.READ_ONLY, 0, length) + } finally { + channel.close() + } + + buffer + } + override def putValues( blockId: String, values: ArrayBuffer[Any], @@ -61,18 +158,18 @@ private class DiskStore(blockManager: BlockManager, rootDirs: String) val file = createFile(blockId) val fileOut = blockManager.wrapForCompression(blockId, new FastBufferedOutputStream(new FileOutputStream(file))) - val objOut = blockManager.serializer.newInstance().serializeStream(fileOut) + val objOut = blockManager.defaultSerializer.newInstance().serializeStream(fileOut) objOut.writeAll(values.iterator) objOut.close() val length = file.length() + + val timeTaken = System.currentTimeMillis - startTime logDebug("Block %s stored as %s file on disk in %d ms".format( - blockId, Utils.memoryBytesToString(length), (System.currentTimeMillis - startTime))) + blockId, Utils.memoryBytesToString(length), timeTaken)) if (returnValues) { // Return a byte buffer for the contents of the file - val channel = new RandomAccessFile(file, "r").getChannel() - val buffer = channel.map(MapMode.READ_ONLY, 0, length) - channel.close() + val buffer = getFileBytes(file) PutResult(length, Right(buffer)) } else { PutResult(length, null) @@ -81,10 +178,7 @@ private class DiskStore(blockManager: BlockManager, rootDirs: String) override def getBytes(blockId: String): Option[ByteBuffer] = { val file = getFile(blockId) - val length = file.length().toInt - val channel = new 
RandomAccessFile(file, "r").getChannel() - val bytes = channel.map(MapMode.READ_ONLY, 0, length) - channel.close() + val bytes = getFileBytes(file) Some(bytes) } @@ -92,11 +186,18 @@ private class DiskStore(blockManager: BlockManager, rootDirs: String) getBytes(blockId).map(bytes => blockManager.dataDeserialize(blockId, bytes)) } + /** + * A version of getValues that allows a custom serializer. This is used as part of the + * shuffle short-circuit code. + */ + def getValues(blockId: String, serializer: Serializer): Option[Iterator[Any]] = { + getBytes(blockId).map(bytes => blockManager.dataDeserialize(blockId, bytes, serializer)) + } + override def remove(blockId: String): Boolean = { val file = getFile(blockId) if (file.exists()) { file.delete() - true } else { false } @@ -106,10 +207,13 @@ private class DiskStore(blockManager: BlockManager, rootDirs: String) getFile(blockId).exists() } - private def createFile(blockId: String): File = { + private def createFile(blockId: String, allowAppendExisting: Boolean = false): File = { val file = getFile(blockId) - if (file.exists()) { - throw new Exception("File for block " + blockId + " already exists on disk: " + file) + if (!allowAppendExisting && file.exists()) { + // NOTE(shivaram): Delete the file if it exists. This might happen if a ShuffleMap task + // was rescheduled on the same machine as the old task. + logWarning("File for block " + blockId + " already exists on disk: " + file + ". 
Deleting") + file.delete() } file } @@ -144,8 +248,8 @@ private class DiskStore(blockManager: BlockManager, rootDirs: String) private def createLocalDirs(): Array[File] = { logDebug("Creating local directories at root dirs '" + rootDirs + "'") val dateFormat = new SimpleDateFormat("yyyyMMddHHmmss") - rootDirs.split(",").map(rootDir => { - var foundLocalDir: Boolean = false + rootDirs.split(",").map { rootDir => + var foundLocalDir = false var localDir: File = null var localDirId: String = null var tries = 0 @@ -156,12 +260,11 @@ private class DiskStore(blockManager: BlockManager, rootDirs: String) localDirId = "%s-%04x".format(dateFormat.format(new Date), rand.nextInt(65536)) localDir = new File(rootDir, "spark-local-" + localDirId) if (!localDir.exists) { - localDir.mkdirs() - foundLocalDir = true + foundLocalDir = localDir.mkdirs() } } catch { case e: Exception => - logWarning("Attempt " + tries + " to create local dir failed", e) + logWarning("Attempt " + tries + " to create local dir " + localDir + " failed", e) } } if (!foundLocalDir) { @@ -171,19 +274,40 @@ private class DiskStore(blockManager: BlockManager, rootDirs: String) } logInfo("Created local directory at " + localDir) localDir - }) + } } private def addShutdownHook() { + localDirs.foreach(localDir => Utils.registerShutdownDeleteDir(localDir)) Runtime.getRuntime.addShutdownHook(new Thread("delete Spark local dirs") { override def run() { logDebug("Shutdown hook called") - try { - localDirs.foreach(localDir => Utils.deleteRecursively(localDir)) - } catch { - case t: Throwable => logError("Exception while deleting local spark dirs", t) + localDirs.foreach { localDir => + try { + if (!Utils.hasRootAsShutdownDeleteDir(localDir)) Utils.deleteRecursively(localDir) + } catch { + case t: Throwable => + logError("Exception while deleting local spark dir: " + localDir, t) + } + } + if (shuffleSender != null) { + shuffleSender.stop } } }) } + + private[storage] def startShuffleBlockSender(port: Int): Int = { + 
val pResolver = new PathResolver { + override def getAbsolutePath(blockId: String): String = { + if (!blockId.startsWith("shuffle_")) { + return null + } + DiskStore.this.getFile(blockId).getAbsolutePath() + } + } + shuffleSender = new ShuffleSender(port, pResolver) + logInfo("Created ShuffleSender binding to port : "+ shuffleSender.port) + shuffleSender.port + } } diff --git a/core/src/main/scala/spark/storage/MemoryStore.scala b/core/src/main/scala/spark/storage/MemoryStore.scala index 949588476c..eba5ee507f 100644 --- a/core/src/main/scala/spark/storage/MemoryStore.scala +++ b/core/src/main/scala/spark/storage/MemoryStore.scala @@ -31,7 +31,9 @@ private class MemoryStore(blockManager: BlockManager, maxMemory: Long) } } - override def putBytes(blockId: String, bytes: ByteBuffer, level: StorageLevel) { + override def putBytes(blockId: String, _bytes: ByteBuffer, level: StorageLevel) { + // Work on a duplicate - since the original input might be used elsewhere. + val bytes = _bytes.duplicate() bytes.rewind() if (level.deserialized) { val values = blockManager.dataDeserialize(blockId, bytes) diff --git a/core/src/main/scala/spark/storage/ShuffleBlockManager.scala b/core/src/main/scala/spark/storage/ShuffleBlockManager.scala new file mode 100644 index 0000000000..44638e0c2d --- /dev/null +++ b/core/src/main/scala/spark/storage/ShuffleBlockManager.scala @@ -0,0 +1,50 @@ +package spark.storage + +import spark.serializer.Serializer + + +private[spark] +class ShuffleWriterGroup(val id: Int, val writers: Array[BlockObjectWriter]) + + +private[spark] +trait ShuffleBlocks { + def acquireWriters(mapId: Int): ShuffleWriterGroup + def releaseWriters(group: ShuffleWriterGroup) +} + + +private[spark] +class ShuffleBlockManager(blockManager: BlockManager) { + + def forShuffle(shuffleId: Int, numBuckets: Int, serializer: Serializer): ShuffleBlocks = { + new ShuffleBlocks { + // Get a group of writers for a map task. 
+ override def acquireWriters(mapId: Int): ShuffleWriterGroup = { + val bufferSize = System.getProperty("spark.shuffle.file.buffer.kb", "100").toInt * 1024 + val writers = Array.tabulate[BlockObjectWriter](numBuckets) { bucketId => + val blockId = ShuffleBlockManager.blockId(shuffleId, bucketId, mapId) + blockManager.getDiskBlockWriter(blockId, serializer, bufferSize) + } + new ShuffleWriterGroup(mapId, writers) + } + + override def releaseWriters(group: ShuffleWriterGroup) = { + // Nothing really to release here. + } + } + } +} + + +private[spark] +object ShuffleBlockManager { + + // Returns the block id for a given shuffle block. + def blockId(shuffleId: Int, bucketId: Int, groupId: Int): String = { + "shuffle_" + shuffleId + "_" + groupId + "_" + bucketId + } + + // Returns true if the block is a shuffle block. + def isShuffle(blockId: String): Boolean = blockId.startsWith("shuffle_") +} diff --git a/core/src/main/scala/spark/storage/StorageLevel.scala b/core/src/main/scala/spark/storage/StorageLevel.scala index 3b5a77ab22..cc0c354e7e 100644 --- a/core/src/main/scala/spark/storage/StorageLevel.scala +++ b/core/src/main/scala/spark/storage/StorageLevel.scala @@ -123,11 +123,7 @@ object StorageLevel { val storageLevelCache = new java.util.concurrent.ConcurrentHashMap[StorageLevel, StorageLevel]() private[spark] def getCachedStorageLevel(level: StorageLevel): StorageLevel = { - if (storageLevelCache.containsKey(level)) { - storageLevelCache.get(level) - } else { - storageLevelCache.put(level, level) - level - } + storageLevelCache.putIfAbsent(level, level) + storageLevelCache.get(level) } } diff --git a/core/src/main/scala/spark/storage/StorageUtils.scala b/core/src/main/scala/spark/storage/StorageUtils.scala index dec47a9d41..950c0cdf35 100644 --- a/core/src/main/scala/spark/storage/StorageUtils.scala +++ b/core/src/main/scala/spark/storage/StorageUtils.scala @@ -4,9 +4,9 @@ import spark.{Utils, SparkContext} import BlockManagerMasterActor.BlockStatus 
private[spark] -case class StorageStatus(blockManagerId: BlockManagerId, maxMem: Long, +case class StorageStatus(blockManagerId: BlockManagerId, maxMem: Long, blocks: Map[String, BlockStatus]) { - + def memUsed(blockPrefix: String = "") = { blocks.filterKeys(_.startsWith(blockPrefix)).values.map(_.memSize). reduceOption(_+_).getOrElse(0l) @@ -22,53 +22,62 @@ case class StorageStatus(blockManagerId: BlockManagerId, maxMem: Long, } case class RDDInfo(id: Int, name: String, storageLevel: StorageLevel, - numCachedPartitions: Int, numPartitions: Int, memSize: Long, diskSize: Long) { + numCachedPartitions: Int, numPartitions: Int, memSize: Long, diskSize: Long) + extends Ordered[RDDInfo] { override def toString = { import Utils.memoryBytesToString "RDD \"%s\" (%d) Storage: %s; CachedPartitions: %d; TotalPartitions: %d; MemorySize: %s; DiskSize: %s".format(name, id, storageLevel.toString, numCachedPartitions, numPartitions, memoryBytesToString(memSize), memoryBytesToString(diskSize)) } + + override def compare(that: RDDInfo) = { + this.id - that.id + } } /* Helper methods for storage-related objects */ private[spark] object StorageUtils { - /* Given the current storage status of the BlockManager, returns information for each RDD */ - def rddInfoFromStorageStatus(storageStatusList: Array[StorageStatus], + /* Given the current storage status of the BlockManager, returns information for each RDD */ + def rddInfoFromStorageStatus(storageStatusList: Array[StorageStatus], sc: SparkContext) : Array[RDDInfo] = { - rddInfoFromBlockStatusList(storageStatusList.flatMap(_.blocks).toMap, sc) + rddInfoFromBlockStatusList(storageStatusList.flatMap(_.blocks).toMap, sc) } - /* Given a list of BlockStatus objets, returns information for each RDD */ - def rddInfoFromBlockStatusList(infos: Map[String, BlockStatus], + /* Given a list of BlockStatus objets, returns information for each RDD */ + def rddInfoFromBlockStatusList(infos: Map[String, BlockStatus], sc: SparkContext) : Array[RDDInfo] = 
{ // Group by rddId, ignore the partition name - val groupedRddBlocks = infos.groupBy { case(k, v) => + val groupedRddBlocks = infos.filterKeys(_.startsWith("rdd_")).groupBy { case(k, v) => k.substring(0,k.lastIndexOf('_')) }.mapValues(_.values.toArray) // For each RDD, generate an RDDInfo object - groupedRddBlocks.map { case(rddKey, rddBlocks) => - + val rddInfos = groupedRddBlocks.map { case (rddKey, rddBlocks) => // Add up memory and disk sizes val memSize = rddBlocks.map(_.memSize).reduce(_ + _) val diskSize = rddBlocks.map(_.diskSize).reduce(_ + _) // Find the id of the RDD, e.g. rdd_1 => 1 val rddId = rddKey.split("_").last.toInt - // Get the friendly name for the rdd, if available. - val rdd = sc.persistentRdds(rddId) - val rddName = Option(rdd.name).getOrElse(rddKey) - val rddStorageLevel = rdd.getStorageLevel - RDDInfo(rddId, rddName, rddStorageLevel, rddBlocks.length, rdd.partitions.size, memSize, diskSize) - }.toArray + // Get the friendly name and storage level for the RDD, if available + sc.persistentRdds.get(rddId).map { r => + val rddName = Option(r.name).getOrElse(rddKey) + val rddStorageLevel = r.getStorageLevel + RDDInfo(rddId, rddName, rddStorageLevel, rddBlocks.length, r.partitions.size, memSize, diskSize) + } + }.flatten.toArray + + scala.util.Sorting.quickSort(rddInfos) + + rddInfos } - /* Removes all BlockStatus object that are not part of a block prefix */ - def filterStorageStatusByPrefix(storageStatusList: Array[StorageStatus], + /* Removes all BlockStatus object that are not part of a block prefix */ + def filterStorageStatusByPrefix(storageStatusList: Array[StorageStatus], prefix: String) : Array[StorageStatus] = { storageStatusList.map { status => diff --git a/core/src/main/scala/spark/util/AkkaUtils.scala b/core/src/main/scala/spark/util/AkkaUtils.scala index e16915c8e9..ea39888c21 100644 --- a/core/src/main/scala/spark/util/AkkaUtils.scala +++ b/core/src/main/scala/spark/util/AkkaUtils.scala @@ -5,13 +5,15 @@ import 
com.typesafe.config.ConfigFactory import scala.concurrent.duration._ import akka.pattern.ask import akka.remote.RemoteActorRefProvider + import spray.routing.Route import spray.io.IOExtension import spray.routing.HttpServiceActor import spray.can.server.{HttpServer, ServerSettings} import spray.io.SingletonHandler import scala.concurrent.Await -import spark.SparkException +import spark.{Utils, SparkException} + import java.util.concurrent.TimeoutException /** @@ -29,9 +31,14 @@ private[spark] object AkkaUtils { def createActorSystem(name: String, host: String, port: Int): (ActorSystem, Int) = { val akkaThreads = System.getProperty("spark.akka.threads", "4").toInt val akkaBatchSize = System.getProperty("spark.akka.batchSize", "15").toInt - val akkaTimeout = System.getProperty("spark.akka.timeout", "20").toInt + + val akkaTimeout = System.getProperty("spark.akka.timeout", "60").toInt + val akkaFrameSize = System.getProperty("spark.akka.frameSize", "10").toInt - val lifecycleEvents = System.getProperty("spark.akka.logLifecycleEvents", "false").toBoolean + val lifecycleEvents = if (System.getProperty("spark.akka.logLifecycleEvents", "false").toBoolean) "on" else "off" + // 10 seconds is the default akka timeout, but in a cluster, we need higher by default. 
+ val akkaWriteTimeout = System.getProperty("spark.akka.writeTimeout", "30").toInt + val akkaConf = ConfigFactory.parseString(""" akka.daemonic = on akka.event-handlers = ["akka.event.slf4j.Slf4jEventHandler"] @@ -45,10 +52,11 @@ private[spark] object AkkaUtils { akka.remote.netty.execution-pool-size = %d akka.actor.default-dispatcher.throughput = %d akka.remote.log-remote-lifecycle-events = %s + akka.remote.netty.write-timeout = %ds """.format(host, port, akkaTimeout, akkaFrameSize, akkaThreads, akkaBatchSize, - if (lifecycleEvents) "on" else "off")) + lifecycleEvents, akkaWriteTimeout)) - val actorSystem = ActorSystem(name, akkaConf, getClass.getClassLoader) + val actorSystem = ActorSystem(name, akkaConf) // Figure out the port number we bound to, in case port was passed as 0. This is a bit of a // hack because Akka doesn't let you figure out the port through the public API yet. @@ -60,12 +68,13 @@ private[spark] object AkkaUtils { /** * Creates a Spray HTTP server bound to a given IP and port with a given Spray Route object to * handle requests. Returns the bound port or throws a SparkException on failure. + * TODO: Not changing ip to host here - is it required ? 
*/ - def startSprayServer(actorSystem: ActorSystem, ip: String, port: Int, route: Route) = { + def startSprayServer(actorSystem: ActorSystem, ip: String, port: Int, route: Route, name: String = "HttpServer") = { val ioWorker = IOExtension(actorSystem).ioBridge() val httpService = actorSystem.actorOf(Props(HttpServiceActor(route))) val server = actorSystem.actorOf( - Props(new HttpServer(ioWorker, SingletonHandler(httpService), ServerSettings())), name = "HttpServer") + Props(new HttpServer(ioWorker, SingletonHandler(httpService), ServerSettings())), name = name) actorSystem.registerOnTermination { actorSystem.stop(ioWorker) } val timeout = 3.seconds val future = server.ask(HttpServer.Bind(ip, port))(timeout) diff --git a/core/src/main/scala/spark/util/BoundedPriorityQueue.scala b/core/src/main/scala/spark/util/BoundedPriorityQueue.scala new file mode 100644 index 0000000000..4bc5db8bb7 --- /dev/null +++ b/core/src/main/scala/spark/util/BoundedPriorityQueue.scala @@ -0,0 +1,45 @@ +package spark.util + +import java.io.Serializable +import java.util.{PriorityQueue => JPriorityQueue} +import scala.collection.generic.Growable +import scala.collection.JavaConverters._ + +/** + * Bounded priority queue. This class wraps the original PriorityQueue + * class and modifies it such that only the top K elements are retained. + * The top K elements are defined by an implicit Ordering[A]. 
+ */ +class BoundedPriorityQueue[A](maxSize: Int)(implicit ord: Ordering[A]) + extends Iterable[A] with Growable[A] with Serializable { + + private val underlying = new JPriorityQueue[A](maxSize, ord) + + override def iterator: Iterator[A] = underlying.iterator.asScala + + override def ++=(xs: TraversableOnce[A]): this.type = { + xs.foreach { this += _ } + this + } + + override def +=(elem: A): this.type = { + if (size < maxSize) underlying.offer(elem) + else maybeReplaceLowest(elem) + this + } + + override def +=(elem1: A, elem2: A, elems: A*): this.type = { + this += elem1 += elem2 ++= elems + } + + override def clear() { underlying.clear() } + + private def maybeReplaceLowest(a: A): Boolean = { + val head = underlying.peek() + if (head != null && ord.gt(a, head)) { + underlying.poll() + underlying.offer(a) + } else false + } +} + diff --git a/core/src/main/scala/spark/util/StatCounter.scala b/core/src/main/scala/spark/util/StatCounter.scala index 5f80180339..2b980340b7 100644 --- a/core/src/main/scala/spark/util/StatCounter.scala +++ b/core/src/main/scala/spark/util/StatCounter.scala @@ -37,17 +37,23 @@ class StatCounter(values: TraversableOnce[Double]) extends Serializable { if (other == this) { merge(other.copy()) // Avoid overwriting fields in a weird order } else { - val delta = other.mu - mu - if (other.n * 10 < n) { - mu = mu + (delta * other.n) / (n + other.n) - } else if (n * 10 < other.n) { - mu = other.mu - (delta * n) / (n + other.n) - } else { - mu = (mu * n + other.mu * other.n) / (n + other.n) + if (n == 0) { + mu = other.mu + m2 = other.m2 + n = other.n + } else if (other.n != 0) { + val delta = other.mu - mu + if (other.n * 10 < n) { + mu = mu + (delta * other.n) / (n + other.n) + } else if (n * 10 < other.n) { + mu = other.mu - (delta * n) / (n + other.n) + } else { + mu = (mu * n + other.mu * other.n) / (n + other.n) + } + m2 += other.m2 + (delta * delta * n * other.n) / (n + other.n) + n += other.n } - m2 += other.m2 + (delta * delta * n * 
other.n) / (n + other.n) - n += other.n - this + this } } diff --git a/core/src/main/scala/spark/util/TimeStampedHashMap.scala b/core/src/main/scala/spark/util/TimeStampedHashMap.scala index 4afba0a4c3..e95ca1fc8e 100644 --- a/core/src/main/scala/spark/util/TimeStampedHashMap.scala +++ b/core/src/main/scala/spark/util/TimeStampedHashMap.scala @@ -3,6 +3,7 @@ package spark.util import java.util.concurrent.ConcurrentHashMap import scala.collection.JavaConversions import scala.collection.mutable.Map +import spark.scheduler.MapStatus /** * This is a custom implementation of scala.collection.mutable.Map which stores the insertion @@ -42,6 +43,13 @@ class TimeStampedHashMap[A, B] extends Map[A, B]() with spark.Logging { this } + // Should we return previous value directly or as Option ? + def putIfAbsent(key: A, value: B): Option[B] = { + val prev = internalMap.putIfAbsent(key, (value, currentTime)) + if (prev != null) Some(prev._1) else None + } + + override def -= (key: A): this.type = { internalMap.remove(key) this diff --git a/core/src/main/scala/spark/util/TimedIterator.scala b/core/src/main/scala/spark/util/TimedIterator.scala deleted file mode 100644 index 539b01f4ce..0000000000 --- a/core/src/main/scala/spark/util/TimedIterator.scala +++ /dev/null @@ -1,32 +0,0 @@ -package spark.util - -/** - * A utility for tracking the total time an iterator takes to iterate through its elements. - * - * In general, this should only be used if you expect it to take a considerable amount of time - * (eg. 
milliseconds) to get each element -- otherwise, the timing won't be very accurate, - * and you are probably just adding more overhead - */ -class TimedIterator[+A](val sub: Iterator[A]) extends Iterator[A] { - private var netMillis = 0l - private var nElems = 0 - def hasNext = { - val start = System.currentTimeMillis() - val r = sub.hasNext - val end = System.currentTimeMillis() - netMillis += (end - start) - r - } - def next = { - val start = System.currentTimeMillis() - val r = sub.next - val end = System.currentTimeMillis() - netMillis += (end - start) - nElems += 1 - r - } - - def getNetMillis = netMillis - def getAverageTimePerItem = netMillis / nElems.toDouble - -} diff --git a/core/src/main/twirl/spark/deploy/master/app_details.scala.html b/core/src/main/twirl/spark/deploy/master/app_details.scala.html index 301a7e2124..5e5e5de551 100644 --- a/core/src/main/twirl/spark/deploy/master/app_details.scala.html +++ b/core/src/main/twirl/spark/deploy/master/app_details.scala.html @@ -9,19 +9,17 @@ <li><strong>ID:</strong> @app.id</li> <li><strong>Description:</strong> @app.desc.name</li> <li><strong>User:</strong> @app.desc.user</li> - <li><strong>Cores:</strong> - @app.desc.cores - (@app.coresGranted Granted - @if(app.desc.cores == Integer.MAX_VALUE) { - + <li><strong>Cores:</strong> + @if(app.desc.maxCores == Integer.MAX_VALUE) { + Unlimited (@app.coresGranted granted) } else { - , @app.coresLeft + @app.desc.maxCores (@app.coresGranted granted, @app.coresLeft left) } - ) </li> <li><strong>Memory per Slave:</strong> @app.desc.memoryPerSlave</li> <li><strong>Submit Date:</strong> @app.submitDate</li> <li><strong>State:</strong> @app.state</li> + <li><strong><a href="@app.appUiUrl">Application Detail UI</a></strong></li> </ul> </div> </div> diff --git a/core/src/main/twirl/spark/deploy/master/executor_row.scala.html b/core/src/main/twirl/spark/deploy/master/executor_row.scala.html index d2d80fad48..21e72c7aab 100644 --- 
a/core/src/main/twirl/spark/deploy/master/executor_row.scala.html +++ b/core/src/main/twirl/spark/deploy/master/executor_row.scala.html @@ -3,7 +3,7 @@ <tr> <td>@executor.id</td> <td> - <a href="@executor.worker.webUiAddress">@executor.worker.id</href> + <a href="@executor.worker.webUiAddress">@executor.worker.id</a> </td> <td>@executor.cores</td> <td>@executor.memory</td> diff --git a/core/src/main/twirl/spark/deploy/master/index.scala.html b/core/src/main/twirl/spark/deploy/master/index.scala.html index ac51a39a51..b9b9f08810 100644 --- a/core/src/main/twirl/spark/deploy/master/index.scala.html +++ b/core/src/main/twirl/spark/deploy/master/index.scala.html @@ -2,7 +2,7 @@ @import spark.deploy.master._ @import spark.Utils -@spark.common.html.layout(title = "Spark Master on " + state.host) { +@spark.common.html.layout(title = "Spark Master on " + state.host + ":" + state.port) { <!-- Cluster Details --> <div class="row"> diff --git a/core/src/main/twirl/spark/deploy/master/worker_row.scala.html b/core/src/main/twirl/spark/deploy/master/worker_row.scala.html index be69e9bf02..46277ca421 100644 --- a/core/src/main/twirl/spark/deploy/master/worker_row.scala.html +++ b/core/src/main/twirl/spark/deploy/master/worker_row.scala.html @@ -4,7 +4,7 @@ <tr> <td> - <a href="@worker.webUiAddress">@worker.id</href> + <a href="@worker.webUiAddress">@worker.id</a> </td> <td>@{worker.host}:@{worker.port}</td> <td>@worker.state</td> diff --git a/core/src/main/twirl/spark/deploy/worker/index.scala.html b/core/src/main/twirl/spark/deploy/worker/index.scala.html index c39f769a73..0e66af9284 100644 --- a/core/src/main/twirl/spark/deploy/worker/index.scala.html +++ b/core/src/main/twirl/spark/deploy/worker/index.scala.html @@ -1,7 +1,7 @@ @(worker: spark.deploy.WorkerState) @import spark.Utils -@spark.common.html.layout(title = "Spark Worker on " + worker.host) { +@spark.common.html.layout(title = "Spark Worker on " + worker.host + ":" + worker.port) { <!-- Worker Details --> <div 
class="row"> diff --git a/core/src/main/twirl/spark/storage/worker_table.scala.html b/core/src/main/twirl/spark/storage/worker_table.scala.html index d54b8de4cc..cd72a688c1 100644 --- a/core/src/main/twirl/spark/storage/worker_table.scala.html +++ b/core/src/main/twirl/spark/storage/worker_table.scala.html @@ -12,7 +12,7 @@ <tbody> @for(status <- workersStatusList) { <tr> - <td>@(status.blockManagerId.ip + ":" + status.blockManagerId.port)</td> + <td>@(status.blockManagerId.host + ":" + status.blockManagerId.port)</td> <td> @(Utils.memoryBytesToString(status.memUsed(prefix))) (@(Utils.memoryBytesToString(status.memRemaining)) Total Available) diff --git a/core/src/test/resources/fairscheduler.xml b/core/src/test/resources/fairscheduler.xml new file mode 100644 index 0000000000..5a688b0ebb --- /dev/null +++ b/core/src/test/resources/fairscheduler.xml @@ -0,0 +1,14 @@ +<allocations> +<pool name="1"> + <minShare>2</minShare> + <weight>1</weight> + <schedulingMode>FIFO</schedulingMode> +</pool> +<pool name="2"> + <minShare>3</minShare> + <weight>1</weight> + <schedulingMode>FIFO</schedulingMode> +</pool> +<pool name="3"> +</pool> +</allocations> diff --git a/core/src/test/scala/spark/CheckpointSuite.scala b/core/src/test/scala/spark/CheckpointSuite.scala index 8836c68ae6..6785787b7e 100644 --- a/core/src/test/scala/spark/CheckpointSuite.scala +++ b/core/src/test/scala/spark/CheckpointSuite.scala @@ -28,6 +28,16 @@ class CheckpointSuite extends FunSuite with LocalSparkContext with Logging { } } + test("basic checkpointing") { + val parCollection = sc.makeRDD(1 to 4) + val flatMappedRDD = parCollection.flatMap(x => 1 to x) + flatMappedRDD.checkpoint() + assert(flatMappedRDD.dependencies.head.rdd == parCollection) + val result = flatMappedRDD.collect() + assert(flatMappedRDD.dependencies.head.rdd != parCollection) + assert(flatMappedRDD.collect() === result) + } + test("RDDs with one-to-one dependencies") { testCheckpointing(_.map(x => x.toString)) 
testCheckpointing(_.flatMap(x => 1 to x)) diff --git a/core/src/test/scala/spark/DistributedSuite.scala b/core/src/test/scala/spark/DistributedSuite.scala index 46b74fe5ee..0866fb47b3 100644 --- a/core/src/test/scala/spark/DistributedSuite.scala +++ b/core/src/test/scala/spark/DistributedSuite.scala @@ -3,8 +3,10 @@ package spark import network.ConnectionManagerId import org.scalatest.FunSuite import org.scalatest.BeforeAndAfter +import org.scalatest.concurrent.Timeouts._ import org.scalatest.matchers.ShouldMatchers import org.scalatest.prop.Checkers +import org.scalatest.time.{Span, Millis} import org.scalacheck.Arbitrary._ import org.scalacheck.Gen import org.scalacheck.Prop._ @@ -16,7 +18,13 @@ import scala.collection.mutable.ArrayBuffer import SparkContext._ import storage.{GetBlock, BlockManagerWorker, StorageLevel} -class DistributedSuite extends FunSuite with ShouldMatchers with BeforeAndAfter with LocalSparkContext { + +class NotSerializableClass +class NotSerializableExn(val notSer: NotSerializableClass) extends Throwable() {} + + +class DistributedSuite extends FunSuite with ShouldMatchers with BeforeAndAfter + with LocalSparkContext { val clusterUrl = "local-cluster[2,1,512]" @@ -25,6 +33,24 @@ class DistributedSuite extends FunSuite with ShouldMatchers with BeforeAndAfter System.clearProperty("spark.storage.memoryFraction") } + test("task throws not serializable exception") { + // Ensures that executors do not crash when an exn is not serializable. If executors crash, + // this test will hang. Correct behavior is that executors don't crash but fail tasks + // and the scheduler throws a SparkException. + + // numSlaves must be less than numPartitions + val numSlaves = 3 + val numPartitions = 10 + + sc = new SparkContext("local-cluster[%s,1,512]".format(numSlaves), "test") + val data = sc.parallelize(1 to 100, numPartitions). 
+ map(x => throw new NotSerializableExn(new NotSerializableClass)) + intercept[SparkException] { + data.count() + } + resetSparkContext() + } + test("local-cluster format") { sc = new SparkContext("local-cluster[2,1,512]", "test") assert(sc.parallelize(1 to 2, 2).count() == 2) @@ -153,7 +179,7 @@ class DistributedSuite extends FunSuite with ShouldMatchers with BeforeAndAfter val blockManager = SparkEnv.get.blockManager blockManager.master.getLocations(blockId).foreach(id => { val bytes = BlockManagerWorker.syncGetBlock( - GetBlock(blockId), ConnectionManagerId(id.ip, id.port)) + GetBlock(blockId), ConnectionManagerId(id.host, id.port)) val deserialized = blockManager.dataDeserialize(blockId, bytes).asInstanceOf[Iterator[Int]].toList assert(deserialized === (1 to 100).toList) }) @@ -196,7 +222,6 @@ class DistributedSuite extends FunSuite with ShouldMatchers with BeforeAndAfter sc = new SparkContext(clusterUrl, "test") val data = sc.parallelize(Seq(true, true), 2) assert(data.count === 2) // force executors to start - val masterId = SparkEnv.get.blockManager.blockManagerId assert(data.map(markNodeIfIdentity).collect.size === 2) assert(data.map(failOnMarkedIdentity).collect.size === 2) } @@ -252,6 +277,42 @@ class DistributedSuite extends FunSuite with ShouldMatchers with BeforeAndAfter assert(data2.count === 2) } } + + test("unpersist RDDs") { + DistributedSuite.amMaster = true + sc = new SparkContext("local-cluster[3,1,512]", "test") + val data = sc.parallelize(Seq(true, false, false, false), 4) + data.persist(StorageLevel.MEMORY_ONLY_2) + data.count + assert(sc.persistentRdds.isEmpty === false) + data.unpersist() + assert(sc.persistentRdds.isEmpty === true) + + failAfter(Span(3000, Millis)) { + try { + while (! sc.getRDDStorageInfo.isEmpty) { + Thread.sleep(200) + } + } catch { + case _ => { Thread.sleep(10) } + // Do nothing. We might see exceptions because block manager + // is racing this thread to remove entries from the driver. 
+ } + } + } + + test("job should fail if TaskResult exceeds Akka frame size") { + // We must use local-cluster mode since results are returned differently + // when running under LocalScheduler: + sc = new SparkContext("local-cluster[1,1,512]", "test") + val akkaFrameSize = + sc.env.actorSystem.settings.config.getBytes("akka.remote.netty.message-frame-size").toInt + val rdd = sc.parallelize(Seq(1)).map{x => new Array[Byte](akkaFrameSize)} + val exception = intercept[SparkException] { + rdd.reduce((x, y) => x) + } + exception.getMessage should endWith("result exceeded Akka frame size") + } } object DistributedSuite { diff --git a/core/src/test/scala/spark/FileSuite.scala b/core/src/test/scala/spark/FileSuite.scala index 91b48c7456..e61ff7793d 100644 --- a/core/src/test/scala/spark/FileSuite.scala +++ b/core/src/test/scala/spark/FileSuite.scala @@ -7,6 +7,8 @@ import scala.io.Source import com.google.common.io.Files import org.scalatest.FunSuite import org.apache.hadoop.io._ +import org.apache.hadoop.io.compress.{DefaultCodec, CompressionCodec, GzipCodec} + import SparkContext._ @@ -26,6 +28,28 @@ class FileSuite extends FunSuite with LocalSparkContext { assert(sc.textFile(outputDir).collect().toList === List("1", "2", "3", "4")) } + test("text files (compressed)") { + sc = new SparkContext("local", "test") + val tempDir = Files.createTempDir() + val normalDir = new File(tempDir, "output_normal").getAbsolutePath + val compressedOutputDir = new File(tempDir, "output_compressed").getAbsolutePath + val codec = new DefaultCodec() + + val data = sc.parallelize("a" * 10000, 1) + data.saveAsTextFile(normalDir) + data.saveAsTextFile(compressedOutputDir, classOf[DefaultCodec]) + + val normalFile = new File(normalDir, "part-00000") + val normalContent = sc.textFile(normalDir).collect + assert(normalContent === Array.fill(10000)("a")) + + val compressedFile = new File(compressedOutputDir, "part-00000" + codec.getDefaultExtension) + val compressedContent = 
sc.textFile(compressedOutputDir).collect + assert(compressedContent === Array.fill(10000)("a")) + + assert(compressedFile.length < normalFile.length) + } + test("SequenceFiles") { sc = new SparkContext("local", "test") val tempDir = Files.createTempDir() @@ -37,6 +61,28 @@ class FileSuite extends FunSuite with LocalSparkContext { assert(output.map(_.toString).collect().toList === List("(1,a)", "(2,aa)", "(3,aaa)")) } + test("SequenceFile (compressed)") { + sc = new SparkContext("local", "test") + val tempDir = Files.createTempDir() + val normalDir = new File(tempDir, "output_normal").getAbsolutePath + val compressedOutputDir = new File(tempDir, "output_compressed").getAbsolutePath + val codec = new DefaultCodec() + + val data = sc.parallelize(Seq.fill(100)("abc"), 1).map(x => (x, x)) + data.saveAsSequenceFile(normalDir) + data.saveAsSequenceFile(compressedOutputDir, Some(classOf[DefaultCodec])) + + val normalFile = new File(normalDir, "part-00000") + val normalContent = sc.sequenceFile[String, String](normalDir).collect + assert(normalContent === Array.fill(100)("abc", "abc")) + + val compressedFile = new File(compressedOutputDir, "part-00000" + codec.getDefaultExtension) + val compressedContent = sc.sequenceFile[String, String](compressedOutputDir).collect + assert(compressedContent === Array.fill(100)("abc", "abc")) + + assert(compressedFile.length < normalFile.length) + } + test("SequenceFile with writable key") { sc = new SparkContext("local", "test") val tempDir = Files.createTempDir() diff --git a/core/src/test/scala/spark/JavaAPISuite.java b/core/src/test/scala/spark/JavaAPISuite.java index d3dcd3bbeb..d306124fca 100644 --- a/core/src/test/scala/spark/JavaAPISuite.java +++ b/core/src/test/scala/spark/JavaAPISuite.java @@ -8,6 +8,7 @@ import java.util.*; import scala.Tuple2; import com.google.common.base.Charsets; +import org.apache.hadoop.io.compress.DefaultCodec; import com.google.common.io.Files; import org.apache.hadoop.io.IntWritable; import 
org.apache.hadoop.io.Text; @@ -474,6 +475,19 @@ public class JavaAPISuite implements Serializable { } @Test + public void textFilesCompressed() throws IOException { + File tempDir = Files.createTempDir(); + String outputDir = new File(tempDir, "output").getAbsolutePath(); + JavaRDD<Integer> rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4)); + rdd.saveAsTextFile(outputDir, DefaultCodec.class); + + // Try reading it in as a text file RDD + List<String> expected = Arrays.asList("1", "2", "3", "4"); + JavaRDD<String> readRDD = sc.textFile(outputDir); + Assert.assertEquals(expected, readRDD.collect()); + } + + @Test public void sequenceFile() { File tempDir = Files.createTempDir(); String outputDir = new File(tempDir, "output").getAbsolutePath(); @@ -620,6 +634,37 @@ public class JavaAPISuite implements Serializable { } @Test + public void hadoopFileCompressed() { + File tempDir = Files.createTempDir(); + String outputDir = new File(tempDir, "output_compressed").getAbsolutePath(); + List<Tuple2<Integer, String>> pairs = Arrays.asList( + new Tuple2<Integer, String>(1, "a"), + new Tuple2<Integer, String>(2, "aa"), + new Tuple2<Integer, String>(3, "aaa") + ); + JavaPairRDD<Integer, String> rdd = sc.parallelizePairs(pairs); + + rdd.map(new PairFunction<Tuple2<Integer, String>, IntWritable, Text>() { + @Override + public Tuple2<IntWritable, Text> call(Tuple2<Integer, String> pair) { + return new Tuple2<IntWritable, Text>(new IntWritable(pair._1()), new Text(pair._2())); + } + }).saveAsHadoopFile(outputDir, IntWritable.class, Text.class, SequenceFileOutputFormat.class, + DefaultCodec.class); + + JavaPairRDD<IntWritable, Text> output = sc.hadoopFile(outputDir, + SequenceFileInputFormat.class, IntWritable.class, Text.class); + + Assert.assertEquals(pairs.toString(), output.map(new Function<Tuple2<IntWritable, Text>, + String>() { + @Override + public String call(Tuple2<IntWritable, Text> x) { + return x.toString(); + } + }).collect().toString()); + } + + @Test public void zip() { 
JavaRDD<Integer> rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5)); JavaDoubleRDD doubles = rdd.map(new DoubleFunction<Integer>() { @@ -633,6 +678,32 @@ public class JavaAPISuite implements Serializable { } @Test + public void zipPartitions() { + JavaRDD<Integer> rdd1 = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5, 6), 2); + JavaRDD<String> rdd2 = sc.parallelize(Arrays.asList("1", "2", "3", "4"), 2); + FlatMapFunction2<Iterator<Integer>, Iterator<String>, Integer> sizesFn = + new FlatMapFunction2<Iterator<Integer>, Iterator<String>, Integer>() { + @Override + public Iterable<Integer> call(Iterator<Integer> i, Iterator<String> s) { + int sizeI = 0; + int sizeS = 0; + while (i.hasNext()) { + sizeI += 1; + i.next(); + } + while (s.hasNext()) { + sizeS += 1; + s.next(); + } + return Arrays.asList(sizeI, sizeS); + } + }; + + JavaRDD<Integer> sizes = rdd1.zipPartitions(sizesFn, rdd2); + Assert.assertEquals("[3, 2, 3, 2]", sizes.collect().toString()); + } + + @Test public void accumulators() { JavaRDD<Integer> rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5)); diff --git a/core/src/test/scala/spark/LocalSparkContext.scala b/core/src/test/scala/spark/LocalSparkContext.scala index ff00dd05dd..76d5258b02 100644 --- a/core/src/test/scala/spark/LocalSparkContext.scala +++ b/core/src/test/scala/spark/LocalSparkContext.scala @@ -27,6 +27,7 @@ object LocalSparkContext { sc.stop() // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown System.clearProperty("spark.driver.port") + System.clearProperty("spark.hostPort") } /** Runs `f` by passing in `sc` and ensures that `sc` is stopped. */ @@ -38,4 +39,4 @@ object LocalSparkContext { } } -}
\ No newline at end of file +} diff --git a/core/src/test/scala/spark/MapOutputTrackerSuite.scala b/core/src/test/scala/spark/MapOutputTrackerSuite.scala index 3abc584b6a..6e585e1c3a 100644 --- a/core/src/test/scala/spark/MapOutputTrackerSuite.scala +++ b/core/src/test/scala/spark/MapOutputTrackerSuite.scala @@ -8,7 +8,7 @@ import spark.storage.BlockManagerId import spark.util.AkkaUtils class MapOutputTrackerSuite extends FunSuite with LocalSparkContext { - + test("compressSize") { assert(MapOutputTracker.compressSize(0L) === 0) assert(MapOutputTracker.compressSize(1L) === 1) @@ -45,13 +45,13 @@ class MapOutputTrackerSuite extends FunSuite with LocalSparkContext { val compressedSize10000 = MapOutputTracker.compressSize(10000L) val size1000 = MapOutputTracker.decompressSize(compressedSize1000) val size10000 = MapOutputTracker.decompressSize(compressedSize10000) - tracker.registerMapOutput(10, 0, new MapStatus(BlockManagerId("a", "hostA", 1000), + tracker.registerMapOutput(10, 0, new MapStatus(BlockManagerId("a", "hostA", 1000, 0), Array(compressedSize1000, compressedSize10000))) - tracker.registerMapOutput(10, 1, new MapStatus(BlockManagerId("b", "hostB", 1000), + tracker.registerMapOutput(10, 1, new MapStatus(BlockManagerId("b", "hostB", 1000, 0), Array(compressedSize10000, compressedSize1000))) val statuses = tracker.getServerStatuses(10, 0) - assert(statuses.toSeq === Seq((BlockManagerId("a", "hostA", 1000), size1000), - (BlockManagerId("b", "hostB", 1000), size10000))) + assert(statuses.toSeq === Seq((BlockManagerId("a", "hostA", 1000, 0), size1000), + (BlockManagerId("b", "hostB", 1000, 0), size10000))) tracker.stop() } @@ -64,14 +64,14 @@ class MapOutputTrackerSuite extends FunSuite with LocalSparkContext { val compressedSize10000 = MapOutputTracker.compressSize(10000L) val size1000 = MapOutputTracker.decompressSize(compressedSize1000) val size10000 = MapOutputTracker.decompressSize(compressedSize10000) - tracker.registerMapOutput(10, 0, new 
MapStatus(BlockManagerId("a", "hostA", 1000), + tracker.registerMapOutput(10, 0, new MapStatus(BlockManagerId("a", "hostA", 1000, 0), Array(compressedSize1000, compressedSize1000, compressedSize1000))) - tracker.registerMapOutput(10, 1, new MapStatus(BlockManagerId("b", "hostB", 1000), + tracker.registerMapOutput(10, 1, new MapStatus(BlockManagerId("b", "hostB", 1000, 0), Array(compressedSize10000, compressedSize1000, compressedSize1000))) // As if we had two simulatenous fetch failures - tracker.unregisterMapOutput(10, 0, BlockManagerId("a", "hostA", 1000)) - tracker.unregisterMapOutput(10, 0, BlockManagerId("a", "hostA", 1000)) + tracker.unregisterMapOutput(10, 0, BlockManagerId("a", "hostA", 1000, 0)) + tracker.unregisterMapOutput(10, 0, BlockManagerId("a", "hostA", 1000, 0)) // The remaining reduce task might try to grab the output despite the shuffle failure; // this should cause it to fail, and the scheduler will ignore the failure due to the @@ -80,16 +80,20 @@ class MapOutputTrackerSuite extends FunSuite with LocalSparkContext { } test("remote fetch") { - val (actorSystem, boundPort) = AkkaUtils.createActorSystem("spark", "localhost", 0) + val hostname = "localhost" + val (actorSystem, boundPort) = AkkaUtils.createActorSystem("spark", hostname, 0) + System.setProperty("spark.driver.port", boundPort.toString) // Will be cleared by LocalSparkContext + System.setProperty("spark.hostPort", hostname + ":" + boundPort) + val masterTracker = new MapOutputTracker() masterTracker.trackerActor = actorSystem.actorOf( Props(new MapOutputTrackerActor(masterTracker)), "MapOutputTracker") - - val (slaveSystem, _) = AkkaUtils.createActorSystem("spark-slave", "localhost", 0) + + val (slaveSystem, _) = AkkaUtils.createActorSystem("spark-slave", hostname, 0) val slaveTracker = new MapOutputTracker() slaveTracker.trackerActor = slaveSystem.actorFor( "akka://spark@localhost:" + boundPort + "/user/MapOutputTracker") - + masterTracker.registerShuffle(10, 1) 
masterTracker.incrementGeneration() slaveTracker.updateGeneration(masterTracker.getGeneration) @@ -98,13 +102,13 @@ class MapOutputTrackerSuite extends FunSuite with LocalSparkContext { val compressedSize1000 = MapOutputTracker.compressSize(1000L) val size1000 = MapOutputTracker.decompressSize(compressedSize1000) masterTracker.registerMapOutput(10, 0, new MapStatus( - BlockManagerId("a", "hostA", 1000), Array(compressedSize1000))) + BlockManagerId("a", "hostA", 1000, 0), Array(compressedSize1000))) masterTracker.incrementGeneration() slaveTracker.updateGeneration(masterTracker.getGeneration) assert(slaveTracker.getServerStatuses(10, 0).toSeq === - Seq((BlockManagerId("a", "hostA", 1000), size1000))) + Seq((BlockManagerId("a", "hostA", 1000, 0), size1000))) - masterTracker.unregisterMapOutput(10, 0, BlockManagerId("a", "hostA", 1000)) + masterTracker.unregisterMapOutput(10, 0, BlockManagerId("a", "hostA", 1000, 0)) masterTracker.incrementGeneration() slaveTracker.updateGeneration(masterTracker.getGeneration) intercept[FetchFailedException] { slaveTracker.getServerStatuses(10, 0) } diff --git a/core/src/test/scala/spark/PairRDDFunctionsSuite.scala b/core/src/test/scala/spark/PairRDDFunctionsSuite.scala new file mode 100644 index 0000000000..682d2745bf --- /dev/null +++ b/core/src/test/scala/spark/PairRDDFunctionsSuite.scala @@ -0,0 +1,287 @@ +package spark + +import scala.collection.mutable.ArrayBuffer +import scala.collection.mutable.HashSet + +import org.scalatest.FunSuite +import org.scalatest.prop.Checkers +import org.scalacheck.Arbitrary._ +import org.scalacheck.Gen +import org.scalacheck.Prop._ + +import com.google.common.io.Files + +import spark.rdd.ShuffledRDD +import spark.SparkContext._ + +class PairRDDFunctionsSuite extends FunSuite with SharedSparkContext { + test("groupByKey") { + val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (2, 1))) + val groups = pairs.groupByKey().collect() + assert(groups.size === 2) + val valuesFor1 = groups.find(_._1 == 
1).get._2 + assert(valuesFor1.toList.sorted === List(1, 2, 3)) + val valuesFor2 = groups.find(_._1 == 2).get._2 + assert(valuesFor2.toList.sorted === List(1)) + } + + test("groupByKey with duplicates") { + val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (1, 1), (2, 1))) + val groups = pairs.groupByKey().collect() + assert(groups.size === 2) + val valuesFor1 = groups.find(_._1 == 1).get._2 + assert(valuesFor1.toList.sorted === List(1, 1, 2, 3)) + val valuesFor2 = groups.find(_._1 == 2).get._2 + assert(valuesFor2.toList.sorted === List(1)) + } + + test("groupByKey with negative key hash codes") { + val pairs = sc.parallelize(Array((-1, 1), (-1, 2), (-1, 3), (2, 1))) + val groups = pairs.groupByKey().collect() + assert(groups.size === 2) + val valuesForMinus1 = groups.find(_._1 == -1).get._2 + assert(valuesForMinus1.toList.sorted === List(1, 2, 3)) + val valuesFor2 = groups.find(_._1 == 2).get._2 + assert(valuesFor2.toList.sorted === List(1)) + } + + test("groupByKey with many output partitions") { + val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (2, 1))) + val groups = pairs.groupByKey(10).collect() + assert(groups.size === 2) + val valuesFor1 = groups.find(_._1 == 1).get._2 + assert(valuesFor1.toList.sorted === List(1, 2, 3)) + val valuesFor2 = groups.find(_._1 == 2).get._2 + assert(valuesFor2.toList.sorted === List(1)) + } + + test("reduceByKey") { + val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (1, 1), (2, 1))) + val sums = pairs.reduceByKey(_+_).collect() + assert(sums.toSet === Set((1, 7), (2, 1))) + } + + test("reduceByKey with collectAsMap") { + val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (1, 1), (2, 1))) + val sums = pairs.reduceByKey(_+_).collectAsMap() + assert(sums.size === 2) + assert(sums(1) === 7) + assert(sums(2) === 1) + } + + test("reduceByKey with many output partitons") { + val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (1, 1), (2, 1))) + val sums = pairs.reduceByKey(_+_, 10).collect() + 
assert(sums.toSet === Set((1, 7), (2, 1))) + } + + test("reduceByKey with partitioner") { + val p = new Partitioner() { + def numPartitions = 2 + def getPartition(key: Any) = key.asInstanceOf[Int] + } + val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 1), (0, 1))).partitionBy(p) + val sums = pairs.reduceByKey(_+_) + assert(sums.collect().toSet === Set((1, 4), (0, 1))) + assert(sums.partitioner === Some(p)) + // count the dependencies to make sure there is only 1 ShuffledRDD + val deps = new HashSet[RDD[_]]() + def visit(r: RDD[_]) { + for (dep <- r.dependencies) { + deps += dep.rdd + visit(dep.rdd) + } + } + visit(sums) + assert(deps.size === 2) // ShuffledRDD, ParallelCollection + } + + test("join") { + val rdd1 = sc.parallelize(Array((1, 1), (1, 2), (2, 1), (3, 1))) + val rdd2 = sc.parallelize(Array((1, 'x'), (2, 'y'), (2, 'z'), (4, 'w'))) + val joined = rdd1.join(rdd2).collect() + assert(joined.size === 4) + assert(joined.toSet === Set( + (1, (1, 'x')), + (1, (2, 'x')), + (2, (1, 'y')), + (2, (1, 'z')) + )) + } + + test("join all-to-all") { + val rdd1 = sc.parallelize(Array((1, 1), (1, 2), (1, 3))) + val rdd2 = sc.parallelize(Array((1, 'x'), (1, 'y'))) + val joined = rdd1.join(rdd2).collect() + assert(joined.size === 6) + assert(joined.toSet === Set( + (1, (1, 'x')), + (1, (1, 'y')), + (1, (2, 'x')), + (1, (2, 'y')), + (1, (3, 'x')), + (1, (3, 'y')) + )) + } + + test("leftOuterJoin") { + val rdd1 = sc.parallelize(Array((1, 1), (1, 2), (2, 1), (3, 1))) + val rdd2 = sc.parallelize(Array((1, 'x'), (2, 'y'), (2, 'z'), (4, 'w'))) + val joined = rdd1.leftOuterJoin(rdd2).collect() + assert(joined.size === 5) + assert(joined.toSet === Set( + (1, (1, Some('x'))), + (1, (2, Some('x'))), + (2, (1, Some('y'))), + (2, (1, Some('z'))), + (3, (1, None)) + )) + } + + test("rightOuterJoin") { + val rdd1 = sc.parallelize(Array((1, 1), (1, 2), (2, 1), (3, 1))) + val rdd2 = sc.parallelize(Array((1, 'x'), (2, 'y'), (2, 'z'), (4, 'w'))) + val joined = 
rdd1.rightOuterJoin(rdd2).collect() + assert(joined.size === 5) + assert(joined.toSet === Set( + (1, (Some(1), 'x')), + (1, (Some(2), 'x')), + (2, (Some(1), 'y')), + (2, (Some(1), 'z')), + (4, (None, 'w')) + )) + } + + test("join with no matches") { + val rdd1 = sc.parallelize(Array((1, 1), (1, 2), (2, 1), (3, 1))) + val rdd2 = sc.parallelize(Array((4, 'x'), (5, 'y'), (5, 'z'), (6, 'w'))) + val joined = rdd1.join(rdd2).collect() + assert(joined.size === 0) + } + + test("join with many output partitions") { + val rdd1 = sc.parallelize(Array((1, 1), (1, 2), (2, 1), (3, 1))) + val rdd2 = sc.parallelize(Array((1, 'x'), (2, 'y'), (2, 'z'), (4, 'w'))) + val joined = rdd1.join(rdd2, 10).collect() + assert(joined.size === 4) + assert(joined.toSet === Set( + (1, (1, 'x')), + (1, (2, 'x')), + (2, (1, 'y')), + (2, (1, 'z')) + )) + } + + test("groupWith") { + val rdd1 = sc.parallelize(Array((1, 1), (1, 2), (2, 1), (3, 1))) + val rdd2 = sc.parallelize(Array((1, 'x'), (2, 'y'), (2, 'z'), (4, 'w'))) + val joined = rdd1.groupWith(rdd2).collect() + assert(joined.size === 4) + assert(joined.toSet === Set( + (1, (ArrayBuffer(1, 2), ArrayBuffer('x'))), + (2, (ArrayBuffer(1), ArrayBuffer('y', 'z'))), + (3, (ArrayBuffer(1), ArrayBuffer())), + (4, (ArrayBuffer(), ArrayBuffer('w'))) + )) + } + + test("zero-partition RDD") { + val emptyDir = Files.createTempDir() + val file = sc.textFile(emptyDir.getAbsolutePath) + assert(file.partitions.size == 0) + assert(file.collect().toList === Nil) + // Test that a shuffle on the file works, because this used to be a bug + assert(file.map(line => (line, 1)).reduceByKey(_ + _).collect().toList === Nil) + } + + test("keys and values") { + val rdd = sc.parallelize(Array((1, "a"), (2, "b"))) + assert(rdd.keys.collect().toList === List(1, 2)) + assert(rdd.values.collect().toList === List("a", "b")) + } + + test("default partitioner uses partition size") { + // specify 2000 partitions + val a = sc.makeRDD(Array(1, 2, 3, 4), 2000) + // do a map, which loses 
the partitioner + val b = a.map(a => (a, (a * 2).toString)) + // then a group by, and see we didn't revert to 2 partitions + val c = b.groupByKey() + assert(c.partitions.size === 2000) + } + + test("default partitioner uses largest partitioner") { + val a = sc.makeRDD(Array((1, "a"), (2, "b")), 2) + val b = sc.makeRDD(Array((1, "a"), (2, "b")), 2000) + val c = a.join(b) + assert(c.partitions.size === 2000) + } + + test("subtract") { + val a = sc.parallelize(Array(1, 2, 3), 2) + val b = sc.parallelize(Array(2, 3, 4), 4) + val c = a.subtract(b) + assert(c.collect().toSet === Set(1)) + assert(c.partitions.size === a.partitions.size) + } + + test("subtract with narrow dependency") { + // use a deterministic partitioner + val p = new Partitioner() { + def numPartitions = 5 + def getPartition(key: Any) = key.asInstanceOf[Int] + } + // partitionBy so we have a narrow dependency + val a = sc.parallelize(Array((1, "a"), (2, "b"), (3, "c"))).partitionBy(p) + // more partitions/no partitioner so a shuffle dependency + val b = sc.parallelize(Array((2, "b"), (3, "cc"), (4, "d")), 4) + val c = a.subtract(b) + assert(c.collect().toSet === Set((1, "a"), (3, "c"))) + // Ideally we could keep the original partitioner... 
+ assert(c.partitioner === None) + } + + test("subtractByKey") { + val a = sc.parallelize(Array((1, "a"), (1, "a"), (2, "b"), (3, "c")), 2) + val b = sc.parallelize(Array((2, 20), (3, 30), (4, 40)), 4) + val c = a.subtractByKey(b) + assert(c.collect().toSet === Set((1, "a"), (1, "a"))) + assert(c.partitions.size === a.partitions.size) + } + + test("subtractByKey with narrow dependency") { + // use a deterministic partitioner + val p = new Partitioner() { + def numPartitions = 5 + def getPartition(key: Any) = key.asInstanceOf[Int] + } + // partitionBy so we have a narrow dependency + val a = sc.parallelize(Array((1, "a"), (1, "a"), (2, "b"), (3, "c"))).partitionBy(p) + // more partitions/no partitioner so a shuffle dependency + val b = sc.parallelize(Array((2, "b"), (3, "cc"), (4, "d")), 4) + val c = a.subtractByKey(b) + assert(c.collect().toSet === Set((1, "a"), (1, "a"))) + assert(c.partitioner.get === p) + } + + test("foldByKey") { + val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (1, 1), (2, 1))) + val sums = pairs.foldByKey(0)(_+_).collect() + assert(sums.toSet === Set((1, 7), (2, 1))) + } + + test("foldByKey with mutable result type") { + val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (1, 1), (2, 1))) + val bufs = pairs.mapValues(v => ArrayBuffer(v)).cache() + // Fold the values using in-place mutation + val sums = bufs.foldByKey(new ArrayBuffer[Int])(_ ++= _).collect() + assert(sums.toSet === Set((1, ArrayBuffer(1, 2, 3, 1)), (2, ArrayBuffer(1)))) + // Check that the mutable objects in the original RDD were not changed + assert(bufs.collect().toSet === Set( + (1, ArrayBuffer(1)), + (1, ArrayBuffer(2)), + (1, ArrayBuffer(3)), + (1, ArrayBuffer(1)), + (2, ArrayBuffer(1)))) + } +} diff --git a/core/src/test/scala/spark/PartitioningSuite.scala b/core/src/test/scala/spark/PartitioningSuite.scala index 60db759c25..99e433e3bd 100644 --- a/core/src/test/scala/spark/PartitioningSuite.scala +++ b/core/src/test/scala/spark/PartitioningSuite.scala @@ 
-1,13 +1,13 @@ package spark import org.scalatest.FunSuite - import scala.collection.mutable.ArrayBuffer - import SparkContext._ +import spark.util.StatCounter +import scala.math.abs + +class PartitioningSuite extends FunSuite with SharedSparkContext { -class PartitioningSuite extends FunSuite with LocalSparkContext { - test("HashPartitioner equality") { val p2 = new HashPartitioner(2) val p4 = new HashPartitioner(4) @@ -21,8 +21,6 @@ class PartitioningSuite extends FunSuite with LocalSparkContext { } test("RangePartitioner equality") { - sc = new SparkContext("local", "test") - // Make an RDD where all the elements are the same so that the partition range bounds // are deterministically all the same. val rdd = sc.parallelize(Seq(1, 1, 1, 1)).map(x => (x, x)) @@ -50,7 +48,6 @@ class PartitioningSuite extends FunSuite with LocalSparkContext { } test("HashPartitioner not equal to RangePartitioner") { - sc = new SparkContext("local", "test") val rdd = sc.parallelize(1 to 10).map(x => (x, x)) val rangeP2 = new RangePartitioner(2, rdd) val hashP2 = new HashPartitioner(2) @@ -61,8 +58,6 @@ class PartitioningSuite extends FunSuite with LocalSparkContext { } test("partitioner preservation") { - sc = new SparkContext("local", "test") - val rdd = sc.parallelize(1 to 10, 4).map(x => (x, x)) val grouped2 = rdd.groupByKey(2) @@ -101,7 +96,6 @@ class PartitioningSuite extends FunSuite with LocalSparkContext { } test("partitioning Java arrays should fail") { - sc = new SparkContext("local", "test") val arrs: RDD[Array[Int]] = sc.parallelize(Array(1, 2, 3, 4), 2).map(x => Array(x)) val arrPairs: RDD[(Array[Int], Int)] = sc.parallelize(Array(1, 2, 3, 4), 2).map(x => (Array(x), x)) @@ -120,4 +114,20 @@ class PartitioningSuite extends FunSuite with LocalSparkContext { assert(intercept[SparkException]{ arrPairs.reduceByKeyLocally(_ + _) }.getMessage.contains("array")) assert(intercept[SparkException]{ arrPairs.reduceByKey(_ + _) }.getMessage.contains("array")) } + + test("zero-length 
partitions should be correctly handled") { + // Create RDD with some consecutive empty partitions (including the "first" one) + val rdd: RDD[Double] = sc + .parallelize(Array(-1.0, -1.0, -1.0, -1.0, 2.0, 4.0, -1.0, -1.0), 8) + .filter(_ >= 0.0) + + // Run the partitions, including the consecutive empty ones, through StatCounter + val stats: StatCounter = rdd.stats(); + assert(abs(6.0 - stats.sum) < 0.01); + assert(abs(6.0/2 - rdd.mean) < 0.01); + assert(abs(1.0 - rdd.variance) < 0.01); + assert(abs(1.0 - rdd.stdev) < 0.01); + + // Add other tests here for classes that should be able to handle empty partitions correctly + } } diff --git a/core/src/test/scala/spark/PipedRDDSuite.scala b/core/src/test/scala/spark/PipedRDDSuite.scala index a6344edf8f..1c9ca50811 100644 --- a/core/src/test/scala/spark/PipedRDDSuite.scala +++ b/core/src/test/scala/spark/PipedRDDSuite.scala @@ -3,10 +3,9 @@ package spark import org.scalatest.FunSuite import SparkContext._ -class PipedRDDSuite extends FunSuite with LocalSparkContext { - +class PipedRDDSuite extends FunSuite with SharedSparkContext { + test("basic pipe") { - sc = new SparkContext("local", "test") val nums = sc.makeRDD(Array(1, 2, 3, 4), 2) val piped = nums.pipe(Seq("cat")) @@ -19,8 +18,45 @@ class PipedRDDSuite extends FunSuite with LocalSparkContext { assert(c(3) === "4") } + test("advanced pipe") { + val nums = sc.makeRDD(Array(1, 2, 3, 4), 2) + val bl = sc.broadcast(List("0")) + + val piped = nums.pipe(Seq("cat"), + Map[String, String](), + (f: String => Unit) => {bl.value.map(f(_));f("\u0001")}, + (i:Int, f: String=> Unit) => f(i + "_")) + + val c = piped.collect() + + assert(c.size === 8) + assert(c(0) === "0") + assert(c(1) === "\u0001") + assert(c(2) === "1_") + assert(c(3) === "2_") + assert(c(4) === "0") + assert(c(5) === "\u0001") + assert(c(6) === "3_") + assert(c(7) === "4_") + + val nums1 = sc.makeRDD(Array("a\t1", "b\t2", "a\t3", "b\t4"), 2) + val d = nums1.groupBy(str=>str.split("\t")(0)). 
+ pipe(Seq("cat"), + Map[String, String](), + (f: String => Unit) => {bl.value.map(f(_));f("\u0001")}, + (i:Tuple2[String, Seq[String]], f: String=> Unit) => {for (e <- i._2){ f(e + "_")}}).collect() + assert(d.size === 8) + assert(d(0) === "0") + assert(d(1) === "\u0001") + assert(d(2) === "b\t2_") + assert(d(3) === "b\t4_") + assert(d(4) === "0") + assert(d(5) === "\u0001") + assert(d(6) === "a\t1_") + assert(d(7) === "a\t3_") + } + test("pipe with env variable") { - sc = new SparkContext("local", "test") val nums = sc.makeRDD(Array(1, 2, 3, 4), 2) val piped = nums.pipe(Seq("printenv", "MY_TEST_ENV"), Map("MY_TEST_ENV" -> "LALALA")) val c = piped.collect() @@ -30,7 +66,6 @@ class PipedRDDSuite extends FunSuite with LocalSparkContext { } test("pipe with non-zero exit status") { - sc = new SparkContext("local", "test") val nums = sc.makeRDD(Array(1, 2, 3, 4), 2) val piped = nums.pipe("cat nonexistent_file") intercept[SparkException] { diff --git a/core/src/test/scala/spark/RDDSuite.scala b/core/src/test/scala/spark/RDDSuite.scala index 7fbdd44340..d8db69b1c9 100644 --- a/core/src/test/scala/spark/RDDSuite.scala +++ b/core/src/test/scala/spark/RDDSuite.scala @@ -2,13 +2,14 @@ package spark import scala.collection.mutable.HashMap import org.scalatest.FunSuite +import org.scalatest.concurrent.Timeouts._ +import org.scalatest.time.{Span, Millis} import spark.SparkContext._ -import spark.rdd.{CoalescedRDD, CoGroupedRDD, PartitionPruningRDD, ShuffledRDD} +import spark.rdd.{CoalescedRDD, CoGroupedRDD, EmptyRDD, PartitionPruningRDD, ShuffledRDD} -class RDDSuite extends FunSuite with LocalSparkContext { +class RDDSuite extends FunSuite with SharedSparkContext { test("basic operations") { - sc = new SparkContext("local", "test") val nums = sc.makeRDD(Array(1, 2, 3, 4), 2) assert(nums.collect().toList === List(1, 2, 3, 4)) val dups = sc.makeRDD(Array(1, 1, 2, 2, 3, 3, 4, 4), 2) @@ -44,7 +45,6 @@ class RDDSuite extends FunSuite with LocalSparkContext { } 
test("SparkContext.union") { - sc = new SparkContext("local", "test") val nums = sc.makeRDD(Array(1, 2, 3, 4), 2) assert(sc.union(nums).collect().toList === List(1, 2, 3, 4)) assert(sc.union(nums, nums).collect().toList === List(1, 2, 3, 4, 1, 2, 3, 4)) @@ -53,7 +53,6 @@ class RDDSuite extends FunSuite with LocalSparkContext { } test("aggregate") { - sc = new SparkContext("local", "test") val pairs = sc.makeRDD(Array(("a", 1), ("b", 2), ("a", 2), ("c", 5), ("a", 3))) type StringMap = HashMap[String, Int] val emptyMap = new StringMap { @@ -73,27 +72,7 @@ class RDDSuite extends FunSuite with LocalSparkContext { assert(result.toSet === Set(("a", 6), ("b", 2), ("c", 5))) } - test("basic checkpointing") { - import java.io.File - val checkpointDir = File.createTempFile("temp", "") - checkpointDir.delete() - - sc = new SparkContext("local", "test") - sc.setCheckpointDir(checkpointDir.toString) - val parCollection = sc.makeRDD(1 to 4) - val flatMappedRDD = parCollection.flatMap(x => 1 to x) - flatMappedRDD.checkpoint() - assert(flatMappedRDD.dependencies.head.rdd == parCollection) - val result = flatMappedRDD.collect() - Thread.sleep(1000) - assert(flatMappedRDD.dependencies.head.rdd != parCollection) - assert(flatMappedRDD.collect() === result) - - checkpointDir.deleteOnExit() - } - test("basic caching") { - sc = new SparkContext("local", "test") val rdd = sc.makeRDD(Array(1, 2, 3, 4), 2).cache() assert(rdd.collect().toList === List(1, 2, 3, 4)) assert(rdd.collect().toList === List(1, 2, 3, 4)) @@ -101,7 +80,6 @@ class RDDSuite extends FunSuite with LocalSparkContext { } test("caching with failures") { - sc = new SparkContext("local", "test") val onlySplit = new Partition { override def index: Int = 0 } var shouldFail = true val rdd = new RDD[Int](sc, Nil) { @@ -123,38 +101,26 @@ class RDDSuite extends FunSuite with LocalSparkContext { assert(rdd.collect().toList === List(1, 2, 3, 4)) } - test("cogrouped RDDs") { - sc = new SparkContext("local", "test") - val rdd1 = 
sc.makeRDD(Array((1, "one"), (1, "another one"), (2, "two"), (3, "three")), 2) - val rdd2 = sc.makeRDD(Array((1, "one1"), (1, "another one1"), (2, "two1")), 2) - - // Use cogroup function - val cogrouped = rdd1.cogroup(rdd2).collectAsMap() - assert(cogrouped(1) === (Seq("one", "another one"), Seq("one1", "another one1"))) - assert(cogrouped(2) === (Seq("two"), Seq("two1"))) - assert(cogrouped(3) === (Seq("three"), Seq())) - - // Construct CoGroupedRDD directly, with map side combine enabled - val cogrouped1 = new CoGroupedRDD[Int]( - Seq(rdd1.asInstanceOf[RDD[(Int, Any)]], rdd2.asInstanceOf[RDD[(Int, Any)]]), - new HashPartitioner(3), - true).collectAsMap() - assert(cogrouped1(1).toSeq === Seq(Seq("one", "another one"), Seq("one1", "another one1"))) - assert(cogrouped1(2).toSeq === Seq(Seq("two"), Seq("two1"))) - assert(cogrouped1(3).toSeq === Seq(Seq("three"), Seq())) + test("empty RDD") { + val empty = new EmptyRDD[Int](sc) + assert(empty.count === 0) + assert(empty.collect().size === 0) - // Construct CoGroupedRDD directly, with map side combine disabled - val cogrouped2 = new CoGroupedRDD[Int]( - Seq(rdd1.asInstanceOf[RDD[(Int, Any)]], rdd2.asInstanceOf[RDD[(Int, Any)]]), - new HashPartitioner(3), - false).collectAsMap() - assert(cogrouped2(1).toSeq === Seq(Seq("one", "another one"), Seq("one1", "another one1"))) - assert(cogrouped2(2).toSeq === Seq(Seq("two"), Seq("two1"))) - assert(cogrouped2(3).toSeq === Seq(Seq("three"), Seq())) + val thrown = intercept[UnsupportedOperationException]{ + empty.reduce(_+_) + } + assert(thrown.getMessage.contains("empty")) + + val emptyKv = new EmptyRDD[(Int, Int)](sc) + val rdd = sc.parallelize(1 to 2, 2).map(x => (x, x)) + assert(rdd.join(emptyKv).collect().size === 0) + assert(rdd.rightOuterJoin(emptyKv).collect().size === 0) + assert(rdd.leftOuterJoin(emptyKv).collect().size === 2) + assert(rdd.cogroup(emptyKv).collect().size === 2) + assert(rdd.union(emptyKv).collect().size === 2) } - test("coalesced RDDs") { - sc = new 
SparkContext("local", "test") + test("cogrouped RDDs") { val data = sc.parallelize(1 to 10, 10) val coalesced1 = data.coalesce(2) @@ -192,7 +158,6 @@ class RDDSuite extends FunSuite with LocalSparkContext { } test("zipped RDDs") { - sc = new SparkContext("local", "test") val nums = sc.makeRDD(Array(1, 2, 3, 4), 2) val zipped = nums.zip(nums.map(_ + 1.0)) assert(zipped.glom().map(_.toList).collect().toList === @@ -204,7 +169,6 @@ class RDDSuite extends FunSuite with LocalSparkContext { } test("partition pruning") { - sc = new SparkContext("local", "test") val data = sc.parallelize(1 to 10, 10) // Note that split number starts from 0, so > 8 means only 10th partition left. val prunedRdd = new PartitionPruningRDD(data, splitNum => splitNum > 8) @@ -216,7 +180,6 @@ class RDDSuite extends FunSuite with LocalSparkContext { test("mapWith") { import java.util.Random - sc = new SparkContext("local", "test") val ones = sc.makeRDD(Array(1, 1, 1, 1, 1, 1), 2) val randoms = ones.mapWith( (index: Int) => new Random(index + 42)) @@ -235,7 +198,6 @@ class RDDSuite extends FunSuite with LocalSparkContext { test("flatMapWith") { import java.util.Random - sc = new SparkContext("local", "test") val ones = sc.makeRDD(Array(1, 1, 1, 1, 1, 1), 2) val randoms = ones.flatMapWith( (index: Int) => new Random(index + 42)) @@ -257,7 +219,6 @@ class RDDSuite extends FunSuite with LocalSparkContext { test("filterWith") { import java.util.Random - sc = new SparkContext("local", "test") val ints = sc.makeRDD(Array(1, 2, 3, 4, 5, 6), 2) val sample = ints.filterWith( (index: Int) => new Random(index + 42)) @@ -273,4 +234,21 @@ class RDDSuite extends FunSuite with LocalSparkContext { assert(sample.size === checkSample.size) for (i <- 0 until sample.size) assert(sample(i) === checkSample(i)) } + + test("top with predefined ordering") { + val nums = Array.range(1, 100000) + val ints = sc.makeRDD(scala.util.Random.shuffle(nums), 2) + val topK = ints.top(5) + assert(topK.size === 5) + assert(topK.sorted 
=== nums.sorted.takeRight(5)) + } + + test("top with custom ordering") { + val words = Vector("a", "b", "c", "d") + implicit val ord = implicitly[Ordering[String]].reverse + val rdd = sc.makeRDD(words, 2) + val topK = rdd.top(2) + assert(topK.size === 2) + assert(topK.sorted === Array("b", "a")) + } } diff --git a/core/src/test/scala/spark/SharedSparkContext.scala b/core/src/test/scala/spark/SharedSparkContext.scala new file mode 100644 index 0000000000..1da79f9824 --- /dev/null +++ b/core/src/test/scala/spark/SharedSparkContext.scala @@ -0,0 +1,25 @@ +package spark + +import org.scalatest.Suite +import org.scalatest.BeforeAndAfterAll + +/** Shares a local `SparkContext` between all tests in a suite and closes it at the end */ +trait SharedSparkContext extends BeforeAndAfterAll { self: Suite => + + @transient private var _sc: SparkContext = _ + + def sc: SparkContext = _sc + + override def beforeAll() { + _sc = new SparkContext("local", "test") + super.beforeAll() + } + + override def afterAll() { + if (_sc != null) { + LocalSparkContext.stop(_sc) + _sc = null + } + super.afterAll() + } +} diff --git a/core/src/test/scala/spark/ShuffleNettySuite.scala b/core/src/test/scala/spark/ShuffleNettySuite.scala new file mode 100644 index 0000000000..bfaffa953e --- /dev/null +++ b/core/src/test/scala/spark/ShuffleNettySuite.scala @@ -0,0 +1,17 @@ +package spark + +import org.scalatest.BeforeAndAfterAll + + +class ShuffleNettySuite extends ShuffleSuite with BeforeAndAfterAll { + + // This test suite should run all tests in ShuffleSuite with Netty shuffle mode. 
+ + override def beforeAll(configMap: Map[String, Any]) { + System.setProperty("spark.shuffle.use.netty", "true") + } + + override def afterAll(configMap: Map[String, Any]) { + System.setProperty("spark.shuffle.use.netty", "false") + } +} diff --git a/core/src/test/scala/spark/ShuffleSuite.scala b/core/src/test/scala/spark/ShuffleSuite.scala index 2b2a90defa..950218fa28 100644 --- a/core/src/test/scala/spark/ShuffleSuite.scala +++ b/core/src/test/scala/spark/ShuffleSuite.scala @@ -16,54 +16,9 @@ import spark.rdd.ShuffledRDD import spark.SparkContext._ class ShuffleSuite extends FunSuite with ShouldMatchers with LocalSparkContext { - - test("groupByKey") { - sc = new SparkContext("local", "test") - val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (2, 1))) - val groups = pairs.groupByKey().collect() - assert(groups.size === 2) - val valuesFor1 = groups.find(_._1 == 1).get._2 - assert(valuesFor1.toList.sorted === List(1, 2, 3)) - val valuesFor2 = groups.find(_._1 == 2).get._2 - assert(valuesFor2.toList.sorted === List(1)) - } - - test("groupByKey with duplicates") { - sc = new SparkContext("local", "test") - val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (1, 1), (2, 1))) - val groups = pairs.groupByKey().collect() - assert(groups.size === 2) - val valuesFor1 = groups.find(_._1 == 1).get._2 - assert(valuesFor1.toList.sorted === List(1, 1, 2, 3)) - val valuesFor2 = groups.find(_._1 == 2).get._2 - assert(valuesFor2.toList.sorted === List(1)) - } - - test("groupByKey with negative key hash codes") { - sc = new SparkContext("local", "test") - val pairs = sc.parallelize(Array((-1, 1), (-1, 2), (-1, 3), (2, 1))) - val groups = pairs.groupByKey().collect() - assert(groups.size === 2) - val valuesForMinus1 = groups.find(_._1 == -1).get._2 - assert(valuesForMinus1.toList.sorted === List(1, 2, 3)) - val valuesFor2 = groups.find(_._1 == 2).get._2 - assert(valuesFor2.toList.sorted === List(1)) - } - - test("groupByKey with many output partitions") { - sc = new 
SparkContext("local", "test") - val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (2, 1))) - val groups = pairs.groupByKey(10).collect() - assert(groups.size === 2) - val valuesFor1 = groups.find(_._1 == 1).get._2 - assert(valuesFor1.toList.sorted === List(1, 2, 3)) - val valuesFor2 = groups.find(_._1 == 2).get._2 - assert(valuesFor2.toList.sorted === List(1)) - } - test("groupByKey with compression") { try { - System.setProperty("spark.blockManager.compress", "true") + System.setProperty("spark.shuffle.compress", "true") sc = new SparkContext("local", "test") val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (2, 1)), 4) val groups = pairs.groupByKey(4).collect() @@ -77,239 +32,100 @@ class ShuffleSuite extends FunSuite with ShouldMatchers with LocalSparkContext { } } - test("reduceByKey") { - sc = new SparkContext("local", "test") - val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (1, 1), (2, 1))) - val sums = pairs.reduceByKey(_+_).collect() - assert(sums.toSet === Set((1, 7), (2, 1))) - } - - test("reduceByKey with collectAsMap") { - sc = new SparkContext("local", "test") - val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (1, 1), (2, 1))) - val sums = pairs.reduceByKey(_+_).collectAsMap() - assert(sums.size === 2) - assert(sums(1) === 7) - assert(sums(2) === 1) - } + test("shuffle non-zero block size") { + sc = new SparkContext("local-cluster[2,1,512]", "test") + val NUM_BLOCKS = 3 - test("reduceByKey with many output partitons") { - sc = new SparkContext("local", "test") - val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (1, 1), (2, 1))) - val sums = pairs.reduceByKey(_+_, 10).collect() - assert(sums.toSet === Set((1, 7), (2, 1))) - } - - test("reduceByKey with partitioner") { - sc = new SparkContext("local", "test") - val p = new Partitioner() { - def numPartitions = 2 - def getPartition(key: Any) = key.asInstanceOf[Int] + val a = sc.parallelize(1 to 10, 2) + val b = a.map { x => + (x, new 
ShuffleSuite.NonJavaSerializableClass(x * 2)) } - val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 1), (0, 1))).partitionBy(p) - val sums = pairs.reduceByKey(_+_) - assert(sums.collect().toSet === Set((1, 4), (0, 1))) - assert(sums.partitioner === Some(p)) - // count the dependencies to make sure there is only 1 ShuffledRDD - val deps = new HashSet[RDD[_]]() - def visit(r: RDD[_]) { - for (dep <- r.dependencies) { - deps += dep.rdd - visit(dep.rdd) - } + // If the Kryo serializer is not used correctly, the shuffle would fail because the + // default Java serializer cannot handle the non serializable class. + val c = new ShuffledRDD(b, new HashPartitioner(NUM_BLOCKS), + classOf[spark.KryoSerializer].getName) + val shuffleId = c.dependencies.head.asInstanceOf[ShuffleDependency[Int, Int]].shuffleId + + assert(c.count === 10) + + // All blocks must have non-zero size + (0 until NUM_BLOCKS).foreach { id => + val statuses = SparkEnv.get.mapOutputTracker.getServerStatuses(shuffleId, id) + assert(statuses.forall(s => s._2 > 0)) } - visit(sums) - assert(deps.size === 2) // ShuffledRDD, ParallelCollection } - test("join") { - sc = new SparkContext("local", "test") - val rdd1 = sc.parallelize(Array((1, 1), (1, 2), (2, 1), (3, 1))) - val rdd2 = sc.parallelize(Array((1, 'x'), (2, 'y'), (2, 'z'), (4, 'w'))) - val joined = rdd1.join(rdd2).collect() - assert(joined.size === 4) - assert(joined.toSet === Set( - (1, (1, 'x')), - (1, (2, 'x')), - (2, (1, 'y')), - (2, (1, 'z')) - )) + test("shuffle serializer") { + // Use a local cluster with 2 processes to make sure there are both local and remote blocks + sc = new SparkContext("local-cluster[2,1,512]", "test") + val a = sc.parallelize(1 to 10, 2) + val b = a.map { x => + (x, new ShuffleSuite.NonJavaSerializableClass(x * 2)) + } + // If the Kryo serializer is not used correctly, the shuffle would fail because the + // default Java serializer cannot handle the non serializable class. 
+ val c = new ShuffledRDD(b, new HashPartitioner(3), classOf[spark.KryoSerializer].getName) + assert(c.count === 10) } - test("join all-to-all") { - sc = new SparkContext("local", "test") - val rdd1 = sc.parallelize(Array((1, 1), (1, 2), (1, 3))) - val rdd2 = sc.parallelize(Array((1, 'x'), (1, 'y'))) - val joined = rdd1.join(rdd2).collect() - assert(joined.size === 6) - assert(joined.toSet === Set( - (1, (1, 'x')), - (1, (1, 'y')), - (1, (2, 'x')), - (1, (2, 'y')), - (1, (3, 'x')), - (1, (3, 'y')) - )) - } + test("zero sized blocks") { + // Use a local cluster with 2 processes to make sure there are both local and remote blocks + sc = new SparkContext("local-cluster[2,1,512]", "test") - test("leftOuterJoin") { - sc = new SparkContext("local", "test") - val rdd1 = sc.parallelize(Array((1, 1), (1, 2), (2, 1), (3, 1))) - val rdd2 = sc.parallelize(Array((1, 'x'), (2, 'y'), (2, 'z'), (4, 'w'))) - val joined = rdd1.leftOuterJoin(rdd2).collect() - assert(joined.size === 5) - assert(joined.toSet === Set( - (1, (1, Some('x'))), - (1, (2, Some('x'))), - (2, (1, Some('y'))), - (2, (1, Some('z'))), - (3, (1, None)) - )) - } + // 10 partitions from 4 keys + val NUM_BLOCKS = 10 + val a = sc.parallelize(1 to 4, NUM_BLOCKS) + val b = a.map(x => (x, x*2)) - test("rightOuterJoin") { - sc = new SparkContext("local", "test") - val rdd1 = sc.parallelize(Array((1, 1), (1, 2), (2, 1), (3, 1))) - val rdd2 = sc.parallelize(Array((1, 'x'), (2, 'y'), (2, 'z'), (4, 'w'))) - val joined = rdd1.rightOuterJoin(rdd2).collect() - assert(joined.size === 5) - assert(joined.toSet === Set( - (1, (Some(1), 'x')), - (1, (Some(2), 'x')), - (2, (Some(1), 'y')), - (2, (Some(1), 'z')), - (4, (None, 'w')) - )) - } + // NOTE: The default Java serializer doesn't create zero-sized blocks. 
+ // So, use Kryo + val c = new ShuffledRDD(b, new HashPartitioner(10), classOf[spark.KryoSerializer].getName) - test("join with no matches") { - sc = new SparkContext("local", "test") - val rdd1 = sc.parallelize(Array((1, 1), (1, 2), (2, 1), (3, 1))) - val rdd2 = sc.parallelize(Array((4, 'x'), (5, 'y'), (5, 'z'), (6, 'w'))) - val joined = rdd1.join(rdd2).collect() - assert(joined.size === 0) - } - - test("join with many output partitions") { - sc = new SparkContext("local", "test") - val rdd1 = sc.parallelize(Array((1, 1), (1, 2), (2, 1), (3, 1))) - val rdd2 = sc.parallelize(Array((1, 'x'), (2, 'y'), (2, 'z'), (4, 'w'))) - val joined = rdd1.join(rdd2, 10).collect() - assert(joined.size === 4) - assert(joined.toSet === Set( - (1, (1, 'x')), - (1, (2, 'x')), - (2, (1, 'y')), - (2, (1, 'z')) - )) - } + val shuffleId = c.dependencies.head.asInstanceOf[ShuffleDependency[Int, Int]].shuffleId + assert(c.count === 4) - test("groupWith") { - sc = new SparkContext("local", "test") - val rdd1 = sc.parallelize(Array((1, 1), (1, 2), (2, 1), (3, 1))) - val rdd2 = sc.parallelize(Array((1, 'x'), (2, 'y'), (2, 'z'), (4, 'w'))) - val joined = rdd1.groupWith(rdd2).collect() - assert(joined.size === 4) - assert(joined.toSet === Set( - (1, (ArrayBuffer(1, 2), ArrayBuffer('x'))), - (2, (ArrayBuffer(1), ArrayBuffer('y', 'z'))), - (3, (ArrayBuffer(1), ArrayBuffer())), - (4, (ArrayBuffer(), ArrayBuffer('w'))) - )) - } + val blockSizes = (0 until NUM_BLOCKS).flatMap { id => + val statuses = SparkEnv.get.mapOutputTracker.getServerStatuses(shuffleId, id) + statuses.map(x => x._2) + } + val nonEmptyBlocks = blockSizes.filter(x => x > 0) - test("zero-partition RDD") { - sc = new SparkContext("local", "test") - val emptyDir = Files.createTempDir() - val file = sc.textFile(emptyDir.getAbsolutePath) - assert(file.partitions.size == 0) - assert(file.collect().toList === Nil) - // Test that a shuffle on the file works, because this used to be a bug - assert(file.map(line => (line, 1)).reduceByKey(_ 
+ _).collect().toList === Nil) + // We should have at most 4 non-zero sized partitions + assert(nonEmptyBlocks.size <= 4) } - test("keys and values") { - sc = new SparkContext("local", "test") - val rdd = sc.parallelize(Array((1, "a"), (2, "b"))) - assert(rdd.keys.collect().toList === List(1, 2)) - assert(rdd.values.collect().toList === List("a", "b")) - } + test("zero sized blocks without kryo") { + // Use a local cluster with 2 processes to make sure there are both local and remote blocks + sc = new SparkContext("local-cluster[2,1,512]", "test") - test("default partitioner uses partition size") { - sc = new SparkContext("local", "test") - // specify 2000 partitions - val a = sc.makeRDD(Array(1, 2, 3, 4), 2000) - // do a map, which loses the partitioner - val b = a.map(a => (a, (a * 2).toString)) - // then a group by, and see we didn't revert to 2 partitions - val c = b.groupByKey() - assert(c.partitions.size === 2000) - } + // 10 partitions from 4 keys + val NUM_BLOCKS = 10 + val a = sc.parallelize(1 to 4, NUM_BLOCKS) + val b = a.map(x => (x, x*2)) - test("default partitioner uses largest partitioner") { - sc = new SparkContext("local", "test") - val a = sc.makeRDD(Array((1, "a"), (2, "b")), 2) - val b = sc.makeRDD(Array((1, "a"), (2, "b")), 2000) - val c = a.join(b) - assert(c.partitions.size === 2000) - } + // NOTE: The default Java serializer should create zero-sized blocks + val c = new ShuffledRDD(b, new HashPartitioner(10)) - test("subtract") { - sc = new SparkContext("local", "test") - val a = sc.parallelize(Array(1, 2, 3), 2) - val b = sc.parallelize(Array(2, 3, 4), 4) - val c = a.subtract(b) - assert(c.collect().toSet === Set(1)) - assert(c.partitions.size === a.partitions.size) - } + val shuffleId = c.dependencies.head.asInstanceOf[ShuffleDependency[Int, Int]].shuffleId + assert(c.count === 4) - test("subtract with narrow dependency") { - sc = new SparkContext("local", "test") - // use a deterministic partitioner - val p = new Partitioner() { - def 
numPartitions = 5 - def getPartition(key: Any) = key.asInstanceOf[Int] + val blockSizes = (0 until NUM_BLOCKS).flatMap { id => + val statuses = SparkEnv.get.mapOutputTracker.getServerStatuses(shuffleId, id) + statuses.map(x => x._2) } - // partitionBy so we have a narrow dependency - val a = sc.parallelize(Array((1, "a"), (2, "b"), (3, "c"))).partitionBy(p) - // more partitions/no partitioner so a shuffle dependency - val b = sc.parallelize(Array((2, "b"), (3, "cc"), (4, "d")), 4) - val c = a.subtract(b) - assert(c.collect().toSet === Set((1, "a"), (3, "c"))) - // Ideally we could keep the original partitioner... - assert(c.partitioner === None) - } - - test("subtractByKey") { - sc = new SparkContext("local", "test") - val a = sc.parallelize(Array((1, "a"), (1, "a"), (2, "b"), (3, "c")), 2) - val b = sc.parallelize(Array((2, 20), (3, 30), (4, 40)), 4) - val c = a.subtractByKey(b) - assert(c.collect().toSet === Set((1, "a"), (1, "a"))) - assert(c.partitions.size === a.partitions.size) - } + val nonEmptyBlocks = blockSizes.filter(x => x > 0) - test("subtractByKey with narrow dependency") { - sc = new SparkContext("local", "test") - // use a deterministic partitioner - val p = new Partitioner() { - def numPartitions = 5 - def getPartition(key: Any) = key.asInstanceOf[Int] - } - // partitionBy so we have a narrow dependency - val a = sc.parallelize(Array((1, "a"), (1, "a"), (2, "b"), (3, "c"))).partitionBy(p) - // more partitions/no partitioner so a shuffle dependency - val b = sc.parallelize(Array((2, "b"), (3, "cc"), (4, "d")), 4) - val c = a.subtractByKey(b) - assert(c.collect().toSet === Set((1, "a"), (1, "a"))) - assert(c.partitioner.get === p) + // We should have at most 4 non-zero sized partitions + assert(nonEmptyBlocks.size <= 4) } - } object ShuffleSuite { + def mergeCombineException(x: Int, y: Int): Int = { throw new SparkException("Exception for map-side combine.") x + y } + + class NonJavaSerializableClass(val value: Int) } diff --git 
a/core/src/test/scala/spark/SizeEstimatorSuite.scala b/core/src/test/scala/spark/SizeEstimatorSuite.scala index 9f3aa6628d..c385965c35 100644 --- a/core/src/test/scala/spark/SizeEstimatorSuite.scala +++ b/core/src/test/scala/spark/SizeEstimatorSuite.scala @@ -78,7 +78,6 @@ class SizeEstimatorSuite // Arrays containing nulls should just have one pointer per element expectResult(56)(SizeEstimator.estimate(new Array[String](10))) expectResult(56)(SizeEstimator.estimate(new Array[AnyRef](10))) - // For object arrays with non-null elements, each object should take one pointer plus // however many bytes that class takes. (Note that Array.fill calls the code in its // second parameter separately for each object, so we get distinct objects.) @@ -115,7 +114,6 @@ class SizeEstimatorSuite expectResult(48)(SizeEstimator.estimate(DummyString("a"))) expectResult(48)(SizeEstimator.estimate(DummyString("ab"))) expectResult(56)(SizeEstimator.estimate(DummyString("abcdefgh"))) - resetOrClear("os.arch", arch) } diff --git a/core/src/test/scala/spark/SortingSuite.scala b/core/src/test/scala/spark/SortingSuite.scala index 495f957e53..f7bf207c68 100644 --- a/core/src/test/scala/spark/SortingSuite.scala +++ b/core/src/test/scala/spark/SortingSuite.scala @@ -5,16 +5,14 @@ import org.scalatest.BeforeAndAfter import org.scalatest.matchers.ShouldMatchers import SparkContext._ -class SortingSuite extends FunSuite with LocalSparkContext with ShouldMatchers with Logging { - +class SortingSuite extends FunSuite with SharedSparkContext with ShouldMatchers with Logging { + test("sortByKey") { - sc = new SparkContext("local", "test") val pairs = sc.parallelize(Array((1, 0), (2, 0), (0, 0), (3, 0)), 2) - assert(pairs.sortByKey().collect() === Array((0,0), (1,0), (2,0), (3,0))) + assert(pairs.sortByKey().collect() === Array((0,0), (1,0), (2,0), (3,0))) } test("large array") { - sc = new SparkContext("local", "test") val rand = new scala.util.Random() val pairArr = Array.fill(1000) { (rand.nextInt(), 
rand.nextInt()) } val pairs = sc.parallelize(pairArr, 2) @@ -24,7 +22,6 @@ class SortingSuite extends FunSuite with LocalSparkContext with ShouldMatchers w } test("large array with one split") { - sc = new SparkContext("local", "test") val rand = new scala.util.Random() val pairArr = Array.fill(1000) { (rand.nextInt(), rand.nextInt()) } val pairs = sc.parallelize(pairArr, 2) @@ -32,9 +29,8 @@ class SortingSuite extends FunSuite with LocalSparkContext with ShouldMatchers w assert(sorted.partitions.size === 1) assert(sorted.collect() === pairArr.sortBy(_._1)) } - + test("large array with many partitions") { - sc = new SparkContext("local", "test") val rand = new scala.util.Random() val pairArr = Array.fill(1000) { (rand.nextInt(), rand.nextInt()) } val pairs = sc.parallelize(pairArr, 2) @@ -42,9 +38,8 @@ class SortingSuite extends FunSuite with LocalSparkContext with ShouldMatchers w assert(sorted.partitions.size === 20) assert(sorted.collect() === pairArr.sortBy(_._1)) } - + test("sort descending") { - sc = new SparkContext("local", "test") val rand = new scala.util.Random() val pairArr = Array.fill(1000) { (rand.nextInt(), rand.nextInt()) } val pairs = sc.parallelize(pairArr, 2) @@ -52,15 +47,13 @@ class SortingSuite extends FunSuite with LocalSparkContext with ShouldMatchers w } test("sort descending with one split") { - sc = new SparkContext("local", "test") val rand = new scala.util.Random() val pairArr = Array.fill(1000) { (rand.nextInt(), rand.nextInt()) } val pairs = sc.parallelize(pairArr, 1) assert(pairs.sortByKey(false, 1).collect() === pairArr.sortWith((x, y) => x._1 > y._1)) } - + test("sort descending with many partitions") { - sc = new SparkContext("local", "test") val rand = new scala.util.Random() val pairArr = Array.fill(1000) { (rand.nextInt(), rand.nextInt()) } val pairs = sc.parallelize(pairArr, 2) @@ -68,7 +61,6 @@ class SortingSuite extends FunSuite with LocalSparkContext with ShouldMatchers w } test("more partitions than elements") { - sc = 
new SparkContext("local", "test") val rand = new scala.util.Random() val pairArr = Array.fill(10) { (rand.nextInt(), rand.nextInt()) } val pairs = sc.parallelize(pairArr, 30) @@ -76,14 +68,12 @@ class SortingSuite extends FunSuite with LocalSparkContext with ShouldMatchers w } test("empty RDD") { - sc = new SparkContext("local", "test") val pairArr = new Array[(Int, Int)](0) val pairs = sc.parallelize(pairArr, 2) assert(pairs.sortByKey().collect() === pairArr.sortBy(_._1)) } test("partition balancing") { - sc = new SparkContext("local", "test") val pairArr = (1 to 1000).map(x => (x, x)).toArray val sorted = sc.parallelize(pairArr, 4).sortByKey() assert(sorted.collect() === pairArr.sortBy(_._1)) @@ -99,7 +89,6 @@ class SortingSuite extends FunSuite with LocalSparkContext with ShouldMatchers w } test("partition balancing for descending sort") { - sc = new SparkContext("local", "test") val pairArr = (1 to 1000).map(x => (x, x)).toArray val sorted = sc.parallelize(pairArr, 4).sortByKey(false) assert(sorted.collect() === pairArr.sortBy(_._1).reverse) diff --git a/core/src/test/scala/spark/UnpersistSuite.scala b/core/src/test/scala/spark/UnpersistSuite.scala new file mode 100644 index 0000000000..94776e7572 --- /dev/null +++ b/core/src/test/scala/spark/UnpersistSuite.scala @@ -0,0 +1,30 @@ +package spark + +import org.scalatest.FunSuite +import org.scalatest.concurrent.Timeouts._ +import org.scalatest.time.{Span, Millis} +import spark.SparkContext._ + +class UnpersistSuite extends FunSuite with LocalSparkContext { + test("unpersist RDD") { + sc = new SparkContext("local", "test") + val rdd = sc.makeRDD(Array(1, 2, 3, 4), 2).cache() + rdd.count + assert(sc.persistentRdds.isEmpty === false) + rdd.unpersist() + assert(sc.persistentRdds.isEmpty === true) + + failAfter(Span(3000, Millis)) { + try { + while (! sc.getRDDStorageInfo.isEmpty) { + Thread.sleep(200) + } + } catch { + case _ => { Thread.sleep(10) } + // Do nothing. 
We might see exceptions because block manager + // is racing this thread to remove entries from the driver. + } + } + assert(sc.getRDDStorageInfo.isEmpty === true) + } +} diff --git a/core/src/test/scala/spark/UtilsSuite.scala b/core/src/test/scala/spark/UtilsSuite.scala index ed4701574f..4a113e16bf 100644 --- a/core/src/test/scala/spark/UtilsSuite.scala +++ b/core/src/test/scala/spark/UtilsSuite.scala @@ -27,24 +27,49 @@ class UtilsSuite extends FunSuite { assert(os.toByteArray.toList.equals(bytes.toList)) } - test("memoryStringToMb"){ - assert(Utils.memoryStringToMb("1") == 0) - assert(Utils.memoryStringToMb("1048575") == 0) - assert(Utils.memoryStringToMb("3145728") == 3) + test("memoryStringToMb") { + assert(Utils.memoryStringToMb("1") === 0) + assert(Utils.memoryStringToMb("1048575") === 0) + assert(Utils.memoryStringToMb("3145728") === 3) - assert(Utils.memoryStringToMb("1024k") == 1) - assert(Utils.memoryStringToMb("5000k") == 4) - assert(Utils.memoryStringToMb("4024k") == Utils.memoryStringToMb("4024K")) + assert(Utils.memoryStringToMb("1024k") === 1) + assert(Utils.memoryStringToMb("5000k") === 4) + assert(Utils.memoryStringToMb("4024k") === Utils.memoryStringToMb("4024K")) - assert(Utils.memoryStringToMb("1024m") == 1024) - assert(Utils.memoryStringToMb("5000m") == 5000) - assert(Utils.memoryStringToMb("4024m") == Utils.memoryStringToMb("4024M")) + assert(Utils.memoryStringToMb("1024m") === 1024) + assert(Utils.memoryStringToMb("5000m") === 5000) + assert(Utils.memoryStringToMb("4024m") === Utils.memoryStringToMb("4024M")) - assert(Utils.memoryStringToMb("2g") == 2048) - assert(Utils.memoryStringToMb("3g") == Utils.memoryStringToMb("3G")) + assert(Utils.memoryStringToMb("2g") === 2048) + assert(Utils.memoryStringToMb("3g") === Utils.memoryStringToMb("3G")) - assert(Utils.memoryStringToMb("2t") == 2097152) - assert(Utils.memoryStringToMb("3t") == Utils.memoryStringToMb("3T")) + assert(Utils.memoryStringToMb("2t") === 2097152) + 
assert(Utils.memoryStringToMb("3t") === Utils.memoryStringToMb("3T")) + } + + test("splitCommandString") { + assert(Utils.splitCommandString("") === Seq()) + assert(Utils.splitCommandString("a") === Seq("a")) + assert(Utils.splitCommandString("aaa") === Seq("aaa")) + assert(Utils.splitCommandString("a b c") === Seq("a", "b", "c")) + assert(Utils.splitCommandString(" a b\t c ") === Seq("a", "b", "c")) + assert(Utils.splitCommandString("a 'b c'") === Seq("a", "b c")) + assert(Utils.splitCommandString("a 'b c' d") === Seq("a", "b c", "d")) + assert(Utils.splitCommandString("'b c'") === Seq("b c")) + assert(Utils.splitCommandString("a \"b c\"") === Seq("a", "b c")) + assert(Utils.splitCommandString("a \"b c\" d") === Seq("a", "b c", "d")) + assert(Utils.splitCommandString("\"b c\"") === Seq("b c")) + assert(Utils.splitCommandString("a 'b\" c' \"d' e\"") === Seq("a", "b\" c", "d' e")) + assert(Utils.splitCommandString("a\t'b\nc'\nd") === Seq("a", "b\nc", "d")) + assert(Utils.splitCommandString("a \"b\\\\c\"") === Seq("a", "b\\c")) + assert(Utils.splitCommandString("a \"b\\\"c\"") === Seq("a", "b\"c")) + assert(Utils.splitCommandString("a 'b\\\"c'") === Seq("a", "b\\\"c")) + assert(Utils.splitCommandString("'a'b") === Seq("ab")) + assert(Utils.splitCommandString("'a''b'") === Seq("ab")) + assert(Utils.splitCommandString("\"a\"b") === Seq("ab")) + assert(Utils.splitCommandString("\"a\"\"b\"") === Seq("ab")) + assert(Utils.splitCommandString("''") === Seq("")) + assert(Utils.splitCommandString("\"\"") === Seq("")) } } diff --git a/core/src/test/scala/spark/ZippedPartitionsSuite.scala b/core/src/test/scala/spark/ZippedPartitionsSuite.scala new file mode 100644 index 0000000000..96cb295f45 --- /dev/null +++ b/core/src/test/scala/spark/ZippedPartitionsSuite.scala @@ -0,0 +1,33 @@ +package spark + +import scala.collection.immutable.NumericRange + +import org.scalatest.FunSuite +import org.scalatest.prop.Checkers +import org.scalacheck.Arbitrary._ +import org.scalacheck.Gen 
+import org.scalacheck.Prop._ + +import SparkContext._ + + +object ZippedPartitionsSuite { + def procZippedData(i: Iterator[Int], s: Iterator[String], d: Iterator[Double]) : Iterator[Int] = { + Iterator(i.toArray.size, s.toArray.size, d.toArray.size) + } +} + +class ZippedPartitionsSuite extends FunSuite with SharedSparkContext { + test("print sizes") { + val data1 = sc.makeRDD(Array(1, 2, 3, 4), 2) + val data2 = sc.makeRDD(Array("1", "2", "3", "4", "5", "6"), 2) + val data3 = sc.makeRDD(Array(1.0, 2.0), 2) + + val zippedRDD = data1.zipPartitions(ZippedPartitionsSuite.procZippedData, data2, data3) + + val obtainedSizes = zippedRDD.collect() + val expectedSizes = Array(2, 3, 1, 2, 3, 1) + assert(obtainedSizes.size == 6) + assert(obtainedSizes.zip(expectedSizes).forall(x => x._1 == x._2)) + } +} diff --git a/core/src/test/scala/spark/rdd/JdbcRDDSuite.scala b/core/src/test/scala/spark/rdd/JdbcRDDSuite.scala new file mode 100644 index 0000000000..6afb0fa9bc --- /dev/null +++ b/core/src/test/scala/spark/rdd/JdbcRDDSuite.scala @@ -0,0 +1,56 @@ +package spark + +import org.scalatest.{ BeforeAndAfter, FunSuite } +import spark.SparkContext._ +import spark.rdd.JdbcRDD +import java.sql._ + +class JdbcRDDSuite extends FunSuite with BeforeAndAfter with LocalSparkContext { + + before { + Class.forName("org.apache.derby.jdbc.EmbeddedDriver") + val conn = DriverManager.getConnection("jdbc:derby:target/JdbcRDDSuiteDb;create=true") + try { + val create = conn.createStatement + create.execute(""" + CREATE TABLE FOO( + ID INTEGER NOT NULL GENERATED ALWAYS AS IDENTITY (START WITH 1, INCREMENT BY 1), + DATA INTEGER + )""") + create.close + val insert = conn.prepareStatement("INSERT INTO FOO(DATA) VALUES(?)") + (1 to 100).foreach { i => + insert.setInt(1, i * 2) + insert.executeUpdate + } + insert.close + } catch { + case e: SQLException if e.getSQLState == "X0Y32" => + // table exists + } finally { + conn.close + } + } + + test("basic functionality") { + sc = new SparkContext("local", 
"test") + val rdd = new JdbcRDD( + sc, + () => { DriverManager.getConnection("jdbc:derby:target/JdbcRDDSuiteDb") }, + "SELECT DATA FROM FOO WHERE ? <= ID AND ID <= ?", + 1, 100, 3, + (r: ResultSet) => { r.getInt(1) } ).cache + + assert(rdd.count === 100) + assert(rdd.reduce(_+_) === 10100) + } + + after { + try { + DriverManager.getConnection("jdbc:derby:;shutdown=true") + } catch { + case se: SQLException if se.getSQLState == "XJ015" => + // normal shutdown + } + } +} diff --git a/core/src/test/scala/spark/scheduler/ClusterSchedulerSuite.scala b/core/src/test/scala/spark/scheduler/ClusterSchedulerSuite.scala new file mode 100644 index 0000000000..8e1ad27e14 --- /dev/null +++ b/core/src/test/scala/spark/scheduler/ClusterSchedulerSuite.scala @@ -0,0 +1,250 @@ +package spark.scheduler + +import org.scalatest.FunSuite +import org.scalatest.BeforeAndAfter + +import spark._ +import spark.scheduler._ +import spark.scheduler.cluster._ +import scala.collection.mutable.ArrayBuffer + +import java.util.Properties + +class DummyTaskSetManager( + initPriority: Int, + initStageId: Int, + initNumTasks: Int, + clusterScheduler: ClusterScheduler, + taskSet: TaskSet) + extends ClusterTaskSetManager(clusterScheduler,taskSet) { + + parent = null + weight = 1 + minShare = 2 + runningTasks = 0 + priority = initPriority + stageId = initStageId + name = "TaskSet_"+stageId + override val numTasks = initNumTasks + tasksFinished = 0 + + override def increaseRunningTasks(taskNum: Int) { + runningTasks += taskNum + if (parent != null) { + parent.increaseRunningTasks(taskNum) + } + } + + override def decreaseRunningTasks(taskNum: Int) { + runningTasks -= taskNum + if (parent != null) { + parent.decreaseRunningTasks(taskNum) + } + } + + override def addSchedulable(schedulable: Schedulable) { + } + + override def removeSchedulable(schedulable: Schedulable) { + } + + override def getSchedulableByName(name: String): Schedulable = { + return null + } + + override def executorLost(executorId: String, 
host: String): Unit = { + } + + override def slaveOffer(execId: String, host: String, avaiableCpus: Double, overrideLocality: TaskLocality.TaskLocality = null): Option[TaskDescription] = { + if (tasksFinished + runningTasks < numTasks) { + increaseRunningTasks(1) + return Some(new TaskDescription(0, execId, "task 0:0", null)) + } + return None + } + + override def checkSpeculatableTasks(): Boolean = { + return true + } + + def taskFinished() { + decreaseRunningTasks(1) + tasksFinished +=1 + if (tasksFinished == numTasks) { + parent.removeSchedulable(this) + } + } + + def abort() { + decreaseRunningTasks(runningTasks) + parent.removeSchedulable(this) + } +} + +class DummyTask(stageId: Int) extends Task[Int](stageId) +{ + def run(attemptId: Long): Int = { + return 0 + } +} + +class ClusterSchedulerSuite extends FunSuite with LocalSparkContext with Logging { + + def createDummyTaskSetManager(priority: Int, stage: Int, numTasks: Int, cs: ClusterScheduler, taskSet: TaskSet): DummyTaskSetManager = { + new DummyTaskSetManager(priority, stage, numTasks, cs , taskSet) + } + + def resourceOffer(rootPool: Pool): Int = { + val taskSetQueue = rootPool.getSortedTaskSetQueue() + /* Just for Test*/ + for (manager <- taskSetQueue) { + logInfo("parentName:%s, parent running tasks:%d, name:%s,runningTasks:%d".format(manager.parent.name, manager.parent.runningTasks, manager.name, manager.runningTasks)) + } + for (taskSet <- taskSetQueue) { + taskSet.slaveOffer("execId_1", "hostname_1", 1) match { + case Some(task) => + return taskSet.stageId + case None => {} + } + } + -1 + } + + def checkTaskSetId(rootPool: Pool, expectedTaskSetId: Int) { + assert(resourceOffer(rootPool) === expectedTaskSetId) + } + + test("FIFO Scheduler Test") { + sc = new SparkContext("local", "ClusterSchedulerSuite") + val clusterScheduler = new ClusterScheduler(sc) + var tasks = ArrayBuffer[Task[_]]() + val task = new DummyTask(0) + tasks += task + val taskSet = new TaskSet(tasks.toArray,0,0,0,null) + + val 
rootPool = new Pool("", SchedulingMode.FIFO, 0, 0) + val schedulableBuilder = new FIFOSchedulableBuilder(rootPool) + schedulableBuilder.buildPools() + + val taskSetManager0 = createDummyTaskSetManager(0, 0, 2, clusterScheduler, taskSet) + val taskSetManager1 = createDummyTaskSetManager(0, 1, 2, clusterScheduler, taskSet) + val taskSetManager2 = createDummyTaskSetManager(0, 2, 2, clusterScheduler, taskSet) + schedulableBuilder.addTaskSetManager(taskSetManager0, null) + schedulableBuilder.addTaskSetManager(taskSetManager1, null) + schedulableBuilder.addTaskSetManager(taskSetManager2, null) + + checkTaskSetId(rootPool, 0) + resourceOffer(rootPool) + checkTaskSetId(rootPool, 1) + resourceOffer(rootPool) + taskSetManager1.abort() + checkTaskSetId(rootPool, 2) + } + + test("Fair Scheduler Test") { + sc = new SparkContext("local", "ClusterSchedulerSuite") + val clusterScheduler = new ClusterScheduler(sc) + var tasks = ArrayBuffer[Task[_]]() + val task = new DummyTask(0) + tasks += task + val taskSet = new TaskSet(tasks.toArray,0,0,0,null) + + val xmlPath = getClass.getClassLoader.getResource("fairscheduler.xml").getFile() + System.setProperty("spark.fairscheduler.allocation.file", xmlPath) + val rootPool = new Pool("", SchedulingMode.FAIR, 0, 0) + val schedulableBuilder = new FairSchedulableBuilder(rootPool) + schedulableBuilder.buildPools() + + assert(rootPool.getSchedulableByName("default") != null) + assert(rootPool.getSchedulableByName("1") != null) + assert(rootPool.getSchedulableByName("2") != null) + assert(rootPool.getSchedulableByName("3") != null) + assert(rootPool.getSchedulableByName("1").minShare === 2) + assert(rootPool.getSchedulableByName("1").weight === 1) + assert(rootPool.getSchedulableByName("2").minShare === 3) + assert(rootPool.getSchedulableByName("2").weight === 1) + assert(rootPool.getSchedulableByName("3").minShare === 2) + assert(rootPool.getSchedulableByName("3").weight === 1) + + val properties1 = new Properties() + 
properties1.setProperty("spark.scheduler.cluster.fair.pool","1") + val properties2 = new Properties() + properties2.setProperty("spark.scheduler.cluster.fair.pool","2") + + val taskSetManager10 = createDummyTaskSetManager(1, 0, 1, clusterScheduler, taskSet) + val taskSetManager11 = createDummyTaskSetManager(1, 1, 1, clusterScheduler, taskSet) + val taskSetManager12 = createDummyTaskSetManager(1, 2, 2, clusterScheduler, taskSet) + schedulableBuilder.addTaskSetManager(taskSetManager10, properties1) + schedulableBuilder.addTaskSetManager(taskSetManager11, properties1) + schedulableBuilder.addTaskSetManager(taskSetManager12, properties1) + + val taskSetManager23 = createDummyTaskSetManager(2, 3, 2, clusterScheduler, taskSet) + val taskSetManager24 = createDummyTaskSetManager(2, 4, 2, clusterScheduler, taskSet) + schedulableBuilder.addTaskSetManager(taskSetManager23, properties2) + schedulableBuilder.addTaskSetManager(taskSetManager24, properties2) + + checkTaskSetId(rootPool, 0) + checkTaskSetId(rootPool, 3) + checkTaskSetId(rootPool, 3) + checkTaskSetId(rootPool, 1) + checkTaskSetId(rootPool, 4) + checkTaskSetId(rootPool, 2) + checkTaskSetId(rootPool, 2) + checkTaskSetId(rootPool, 4) + + taskSetManager12.taskFinished() + assert(rootPool.getSchedulableByName("1").runningTasks === 3) + taskSetManager24.abort() + assert(rootPool.getSchedulableByName("2").runningTasks === 2) + } + + test("Nested Pool Test") { + sc = new SparkContext("local", "ClusterSchedulerSuite") + val clusterScheduler = new ClusterScheduler(sc) + var tasks = ArrayBuffer[Task[_]]() + val task = new DummyTask(0) + tasks += task + val taskSet = new TaskSet(tasks.toArray,0,0,0,null) + + val rootPool = new Pool("", SchedulingMode.FAIR, 0, 0) + val pool0 = new Pool("0", SchedulingMode.FAIR, 3, 1) + val pool1 = new Pool("1", SchedulingMode.FAIR, 4, 1) + rootPool.addSchedulable(pool0) + rootPool.addSchedulable(pool1) + + val pool00 = new Pool("00", SchedulingMode.FAIR, 2, 2) + val pool01 = new Pool("01", 
SchedulingMode.FAIR, 1, 1) + pool0.addSchedulable(pool00) + pool0.addSchedulable(pool01) + + val pool10 = new Pool("10", SchedulingMode.FAIR, 2, 2) + val pool11 = new Pool("11", SchedulingMode.FAIR, 2, 1) + pool1.addSchedulable(pool10) + pool1.addSchedulable(pool11) + + val taskSetManager000 = createDummyTaskSetManager(0, 0, 5, clusterScheduler, taskSet) + val taskSetManager001 = createDummyTaskSetManager(0, 1, 5, clusterScheduler, taskSet) + pool00.addSchedulable(taskSetManager000) + pool00.addSchedulable(taskSetManager001) + + val taskSetManager010 = createDummyTaskSetManager(1, 2, 5, clusterScheduler, taskSet) + val taskSetManager011 = createDummyTaskSetManager(1, 3, 5, clusterScheduler, taskSet) + pool01.addSchedulable(taskSetManager010) + pool01.addSchedulable(taskSetManager011) + + val taskSetManager100 = createDummyTaskSetManager(2, 4, 5, clusterScheduler, taskSet) + val taskSetManager101 = createDummyTaskSetManager(2, 5, 5, clusterScheduler, taskSet) + pool10.addSchedulable(taskSetManager100) + pool10.addSchedulable(taskSetManager101) + + val taskSetManager110 = createDummyTaskSetManager(3, 6, 5, clusterScheduler, taskSet) + val taskSetManager111 = createDummyTaskSetManager(3, 7, 5, clusterScheduler, taskSet) + pool11.addSchedulable(taskSetManager110) + pool11.addSchedulable(taskSetManager111) + + checkTaskSetId(rootPool, 0) + checkTaskSetId(rootPool, 4) + checkTaskSetId(rootPool, 6) + checkTaskSetId(rootPool, 2) + } +} diff --git a/core/src/test/scala/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/spark/scheduler/DAGSchedulerSuite.scala index 6da58a0f6e..30e6fef950 100644 --- a/core/src/test/scala/spark/scheduler/DAGSchedulerSuite.scala +++ b/core/src/test/scala/spark/scheduler/DAGSchedulerSuite.scala @@ -44,7 +44,7 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with LocalSparkCont override def submitTasks(taskSet: TaskSet) = { // normally done by TaskSetManager taskSet.tasks.foreach(_.generation = 
mapOutputTracker.getGeneration) - taskSets += taskSet + taskSets += taskSet } override def setListener(listener: TaskSchedulerListener) = {} override def defaultParallelism() = 2 @@ -164,7 +164,7 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with LocalSparkCont } } } - + /** Sends the rdd to the scheduler for scheduling. */ private def submit( rdd: RDD[_], @@ -174,7 +174,7 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with LocalSparkCont listener: JobListener = listener) { runEvent(JobSubmitted(rdd, func, partitions, allowLocal, null, listener)) } - + /** Sends TaskSetFailed to the scheduler. */ private def failed(taskSet: TaskSet, message: String) { runEvent(TaskSetFailed(taskSet, message)) @@ -209,11 +209,11 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with LocalSparkCont runEvent(JobSubmitted(rdd, jobComputeFunc, Array(0), true, null, listener)) assert(results === Map(0 -> 42)) } - + test("run trivial job w/ dependency") { val baseRdd = makeRdd(1, Nil) val finalRdd = makeRdd(1, List(new OneToOneDependency(baseRdd))) - submit(finalRdd, Array(0)) + submit(finalRdd, Array(0)) complete(taskSets(0), Seq((Success, 42))) assert(results === Map(0 -> 42)) } @@ -250,7 +250,7 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with LocalSparkCont complete(taskSets(1), Seq((Success, 42))) assert(results === Map(0 -> 42)) } - + test("run trivial shuffle with fetch failure") { val shuffleMapRdd = makeRdd(2, Nil) val shuffleDep = new ShuffleDependency(shuffleMapRdd, null) @@ -271,7 +271,7 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with LocalSparkCont // have the 2nd attempt pass complete(taskSets(2), Seq((Success, makeMapStatus("hostA", 1)))) // we can see both result blocks now - assert(mapOutputTracker.getServerStatuses(shuffleId, 0).map(_._1.ip) === Array("hostA", "hostB")) + assert(mapOutputTracker.getServerStatuses(shuffleId, 0).map(_._1.host) === Array("hostA", "hostB")) 
complete(taskSets(3), Seq((Success, 43))) assert(results === Map(0 -> 42, 1 -> 43)) } @@ -385,12 +385,12 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with LocalSparkCont assert(results === Map(0 -> 42)) } - /** Assert that the supplied TaskSet has exactly the given preferredLocations. */ + /** Assert that the supplied TaskSet has exactly the given preferredLocations. Note, converts taskSet's locations to host only. */ private def assertLocations(taskSet: TaskSet, locations: Seq[Seq[String]]) { assert(locations.size === taskSet.tasks.size) for ((expectLocs, taskLocs) <- taskSet.tasks.map(_.preferredLocations).zip(locations)) { - assert(expectLocs === taskLocs) + assert(expectLocs.map(loc => spark.Utils.parseHostPort(loc)._1) === taskLocs) } } @@ -398,6 +398,6 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with LocalSparkCont new MapStatus(makeBlockManagerId(host), Array.fill[Byte](reduces)(2)) private def makeBlockManagerId(host: String): BlockManagerId = - BlockManagerId("exec-" + host, host, 12345) + BlockManagerId("exec-" + host, host, 12345, 0) } diff --git a/core/src/test/scala/spark/scheduler/JobLoggerSuite.scala b/core/src/test/scala/spark/scheduler/JobLoggerSuite.scala new file mode 100644 index 0000000000..699901f1a1 --- /dev/null +++ b/core/src/test/scala/spark/scheduler/JobLoggerSuite.scala @@ -0,0 +1,104 @@ +package spark.scheduler + +import java.util.Properties +import java.util.concurrent.LinkedBlockingQueue +import org.scalatest.FunSuite +import org.scalatest.matchers.ShouldMatchers +import scala.collection.mutable +import spark._ +import spark.SparkContext._ + + +class JobLoggerSuite extends FunSuite with LocalSparkContext with ShouldMatchers { + + test("inner method") { + sc = new SparkContext("local", "joblogger") + val joblogger = new JobLogger { + def createLogWriterTest(jobID: Int) = createLogWriter(jobID) + def closeLogWriterTest(jobID: Int) = closeLogWriter(jobID) + def getRddNameTest(rdd: RDD[_]) = 
getRddName(rdd) + def buildJobDepTest(jobID: Int, stage: Stage) = buildJobDep(jobID, stage) + } + type MyRDD = RDD[(Int, Int)] + def makeRdd( + numPartitions: Int, + dependencies: List[Dependency[_]] + ): MyRDD = { + val maxPartition = numPartitions - 1 + return new MyRDD(sc, dependencies) { + override def compute(split: Partition, context: TaskContext): Iterator[(Int, Int)] = + throw new RuntimeException("should not be reached") + override def getPartitions = (0 to maxPartition).map(i => new Partition { + override def index = i + }).toArray + } + } + val jobID = 5 + val parentRdd = makeRdd(4, Nil) + val shuffleDep = new ShuffleDependency(parentRdd, null) + val rootRdd = makeRdd(4, List(shuffleDep)) + val shuffleMapStage = new Stage(1, parentRdd, Some(shuffleDep), Nil, jobID) + val rootStage = new Stage(0, rootRdd, None, List(shuffleMapStage), jobID) + + joblogger.onStageSubmitted(SparkListenerStageSubmitted(rootStage, 4)) + joblogger.getRddNameTest(parentRdd) should be (parentRdd.getClass.getName) + parentRdd.setName("MyRDD") + joblogger.getRddNameTest(parentRdd) should be ("MyRDD") + joblogger.createLogWriterTest(jobID) + joblogger.getJobIDtoPrintWriter.size should be (1) + joblogger.buildJobDepTest(jobID, rootStage) + joblogger.getJobIDToStages.get(jobID).get.size should be (2) + joblogger.getStageIDToJobID.get(0) should be (Some(jobID)) + joblogger.getStageIDToJobID.get(1) should be (Some(jobID)) + joblogger.closeLogWriterTest(jobID) + joblogger.getStageIDToJobID.size should be (0) + joblogger.getJobIDToStages.size should be (0) + joblogger.getJobIDtoPrintWriter.size should be (0) + } + + test("inner variables") { + sc = new SparkContext("local[4]", "joblogger") + val joblogger = new JobLogger { + override protected def closeLogWriter(jobID: Int) = + getJobIDtoPrintWriter.get(jobID).foreach { fileWriter => + fileWriter.close() + } + } + sc.addSparkListener(joblogger) + val rdd = sc.parallelize(1 to 1e2.toInt, 4).map{ i => (i % 12, 2 * i) } + 
rdd.reduceByKey(_+_).collect() + + joblogger.getLogDir should be ("/tmp/spark") + joblogger.getJobIDtoPrintWriter.size should be (1) + joblogger.getStageIDToJobID.size should be (2) + joblogger.getStageIDToJobID.get(0) should be (Some(0)) + joblogger.getStageIDToJobID.get(1) should be (Some(0)) + joblogger.getJobIDToStages.size should be (1) + } + + + test("interface functions") { + sc = new SparkContext("local[4]", "joblogger") + val joblogger = new JobLogger { + var onTaskEndCount = 0 + var onJobEndCount = 0 + var onJobStartCount = 0 + var onStageCompletedCount = 0 + var onStageSubmittedCount = 0 + override def onTaskEnd(taskEnd: SparkListenerTaskEnd) = onTaskEndCount += 1 + override def onJobEnd(jobEnd: SparkListenerJobEnd) = onJobEndCount += 1 + override def onJobStart(jobStart: SparkListenerJobStart) = onJobStartCount += 1 + override def onStageCompleted(stageCompleted: StageCompleted) = onStageCompletedCount += 1 + override def onStageSubmitted(stageSubmitted: SparkListenerStageSubmitted) = onStageSubmittedCount += 1 + } + sc.addSparkListener(joblogger) + val rdd = sc.parallelize(1 to 1e2.toInt, 4).map{ i => (i % 12, 2 * i) } + rdd.reduceByKey(_+_).collect() + + joblogger.onJobStartCount should be (1) + joblogger.onJobEndCount should be (1) + joblogger.onTaskEndCount should be (8) + joblogger.onStageSubmittedCount should be (2) + joblogger.onStageCompletedCount should be (2) + } +} diff --git a/core/src/test/scala/spark/scheduler/LocalSchedulerSuite.scala b/core/src/test/scala/spark/scheduler/LocalSchedulerSuite.scala new file mode 100644 index 0000000000..8bd813fd14 --- /dev/null +++ b/core/src/test/scala/spark/scheduler/LocalSchedulerSuite.scala @@ -0,0 +1,206 @@ +package spark.scheduler + +import org.scalatest.FunSuite +import org.scalatest.BeforeAndAfter + +import spark._ +import spark.scheduler._ +import spark.scheduler.cluster._ +import scala.collection.mutable.ArrayBuffer +import scala.collection.mutable.{ConcurrentMap, HashMap} +import 
java.util.concurrent.Semaphore +import java.util.concurrent.CountDownLatch +import java.util.Properties + +class Lock() { + var finished = false + def jobWait() = { + synchronized { + while(!finished) { + this.wait() + } + } + } + + def jobFinished() = { + synchronized { + finished = true + this.notifyAll() + } + } +} + +object TaskThreadInfo { + val threadToLock = HashMap[Int, Lock]() + val threadToRunning = HashMap[Int, Boolean]() + val threadToStarted = HashMap[Int, CountDownLatch]() +} + +/* + * 1. each thread contains one job. + * 2. each job contains one stage. + * 3. each stage only contains one task. + * 4. each task(launched) must be launched in order (using threadToStarted) to make sure + * it will get a cpu core; it then waits to finish until the user manually + * releases its "Lock", after which the cluster will have another free cpu core. + * 5. each task(pending) must use "sleep" to make sure it has been added to taskSetManager queue, + * thus it will be scheduled later when cluster has free cpu cores. 
+ */ +class LocalSchedulerSuite extends FunSuite with LocalSparkContext { + + def createThread(threadIndex: Int, poolName: String, sc: SparkContext, sem: Semaphore) { + + TaskThreadInfo.threadToRunning(threadIndex) = false + val nums = sc.parallelize(threadIndex to threadIndex, 1) + TaskThreadInfo.threadToLock(threadIndex) = new Lock() + TaskThreadInfo.threadToStarted(threadIndex) = new CountDownLatch(1) + new Thread { + if (poolName != null) { + sc.addLocalProperties("spark.scheduler.cluster.fair.pool",poolName) + } + override def run() { + val ans = nums.map(number => { + TaskThreadInfo.threadToRunning(number) = true + TaskThreadInfo.threadToStarted(number).countDown() + TaskThreadInfo.threadToLock(number).jobWait() + TaskThreadInfo.threadToRunning(number) = false + number + }).collect() + assert(ans.toList === List(threadIndex)) + sem.release() + } + }.start() + } + + test("Local FIFO scheduler end-to-end test") { + System.setProperty("spark.cluster.schedulingmode", "FIFO") + sc = new SparkContext("local[4]", "test") + val sem = new Semaphore(0) + + createThread(1,null,sc,sem) + TaskThreadInfo.threadToStarted(1).await() + createThread(2,null,sc,sem) + TaskThreadInfo.threadToStarted(2).await() + createThread(3,null,sc,sem) + TaskThreadInfo.threadToStarted(3).await() + createThread(4,null,sc,sem) + TaskThreadInfo.threadToStarted(4).await() + // threads 5 and 6 (stage pending) must meet the following two points + // 1. stages (taskSetManager) of jobs in threads 5 and 6 should be added to the taskSetManager + // queue before executing TaskThreadInfo.threadToLock(1).jobFinished() + // 2. the stage in thread 5 should take priority over the stage in thread 6 + // So I just use "sleep" 1s here for each thread. + // TODO: any better solution? 
+ createThread(5,null,sc,sem) + Thread.sleep(1000) + createThread(6,null,sc,sem) + Thread.sleep(1000) + + assert(TaskThreadInfo.threadToRunning(1) === true) + assert(TaskThreadInfo.threadToRunning(2) === true) + assert(TaskThreadInfo.threadToRunning(3) === true) + assert(TaskThreadInfo.threadToRunning(4) === true) + assert(TaskThreadInfo.threadToRunning(5) === false) + assert(TaskThreadInfo.threadToRunning(6) === false) + + TaskThreadInfo.threadToLock(1).jobFinished() + TaskThreadInfo.threadToStarted(5).await() + + assert(TaskThreadInfo.threadToRunning(1) === false) + assert(TaskThreadInfo.threadToRunning(2) === true) + assert(TaskThreadInfo.threadToRunning(3) === true) + assert(TaskThreadInfo.threadToRunning(4) === true) + assert(TaskThreadInfo.threadToRunning(5) === true) + assert(TaskThreadInfo.threadToRunning(6) === false) + + TaskThreadInfo.threadToLock(3).jobFinished() + TaskThreadInfo.threadToStarted(6).await() + + assert(TaskThreadInfo.threadToRunning(1) === false) + assert(TaskThreadInfo.threadToRunning(2) === true) + assert(TaskThreadInfo.threadToRunning(3) === false) + assert(TaskThreadInfo.threadToRunning(4) === true) + assert(TaskThreadInfo.threadToRunning(5) === true) + assert(TaskThreadInfo.threadToRunning(6) === true) + + TaskThreadInfo.threadToLock(2).jobFinished() + TaskThreadInfo.threadToLock(4).jobFinished() + TaskThreadInfo.threadToLock(5).jobFinished() + TaskThreadInfo.threadToLock(6).jobFinished() + sem.acquire(6) + } + + test("Local fair scheduler end-to-end test") { + sc = new SparkContext("local[8]", "LocalSchedulerSuite") + val sem = new Semaphore(0) + System.setProperty("spark.cluster.schedulingmode", "FAIR") + val xmlPath = getClass.getClassLoader.getResource("fairscheduler.xml").getFile() + System.setProperty("spark.fairscheduler.allocation.file", xmlPath) + + createThread(10,"1",sc,sem) + TaskThreadInfo.threadToStarted(10).await() + createThread(20,"2",sc,sem) + TaskThreadInfo.threadToStarted(20).await() + createThread(30,"3",sc,sem) 
+ TaskThreadInfo.threadToStarted(30).await() + + assert(TaskThreadInfo.threadToRunning(10) === true) + assert(TaskThreadInfo.threadToRunning(20) === true) + assert(TaskThreadInfo.threadToRunning(30) === true) + + createThread(11,"1",sc,sem) + TaskThreadInfo.threadToStarted(11).await() + createThread(21,"2",sc,sem) + TaskThreadInfo.threadToStarted(21).await() + createThread(31,"3",sc,sem) + TaskThreadInfo.threadToStarted(31).await() + + assert(TaskThreadInfo.threadToRunning(11) === true) + assert(TaskThreadInfo.threadToRunning(21) === true) + assert(TaskThreadInfo.threadToRunning(31) === true) + + createThread(12,"1",sc,sem) + TaskThreadInfo.threadToStarted(12).await() + createThread(22,"2",sc,sem) + TaskThreadInfo.threadToStarted(22).await() + createThread(32,"3",sc,sem) + + assert(TaskThreadInfo.threadToRunning(12) === true) + assert(TaskThreadInfo.threadToRunning(22) === true) + assert(TaskThreadInfo.threadToRunning(32) === false) + + TaskThreadInfo.threadToLock(10).jobFinished() + TaskThreadInfo.threadToStarted(32).await() + + assert(TaskThreadInfo.threadToRunning(32) === true) + + //1. Similar with above scenario, sleep 1s for stage of 23 and 33 to be added to taskSetManager + // queue so that cluster will assign free cpu core to stage 23 after stage 11 finished. + //2. priority of 23 and 33 will be meaningless as using fair scheduler here. 
+ createThread(23,"2",sc,sem) + createThread(33,"3",sc,sem) + Thread.sleep(1000) + + TaskThreadInfo.threadToLock(11).jobFinished() + TaskThreadInfo.threadToStarted(23).await() + + assert(TaskThreadInfo.threadToRunning(23) === true) + assert(TaskThreadInfo.threadToRunning(33) === false) + + TaskThreadInfo.threadToLock(12).jobFinished() + TaskThreadInfo.threadToStarted(33).await() + + assert(TaskThreadInfo.threadToRunning(33) === true) + + TaskThreadInfo.threadToLock(20).jobFinished() + TaskThreadInfo.threadToLock(21).jobFinished() + TaskThreadInfo.threadToLock(22).jobFinished() + TaskThreadInfo.threadToLock(23).jobFinished() + TaskThreadInfo.threadToLock(30).jobFinished() + TaskThreadInfo.threadToLock(31).jobFinished() + TaskThreadInfo.threadToLock(32).jobFinished() + TaskThreadInfo.threadToLock(33).jobFinished() + + sem.acquire(11) + } +} diff --git a/core/src/test/scala/spark/scheduler/SparkListenerSuite.scala b/core/src/test/scala/spark/scheduler/SparkListenerSuite.scala index 2f5af10e69..48aa67c543 100644 --- a/core/src/test/scala/spark/scheduler/SparkListenerSuite.scala +++ b/core/src/test/scala/spark/scheduler/SparkListenerSuite.scala @@ -57,7 +57,6 @@ class SparkListenerSuite extends FunSuite with LocalSparkContext with ShouldMatc taskMetrics.shuffleReadMetrics should be ('defined) val sm = taskMetrics.shuffleReadMetrics.get sm.totalBlocksFetched should be > (0) - sm.shuffleReadMillis should be > (0l) sm.localBlocksFetched should be > (0) sm.remoteBlocksFetched should be (0) sm.remoteBytesRead should be (0l) @@ -78,7 +77,7 @@ class SparkListenerSuite extends FunSuite with LocalSparkContext with ShouldMatc class SaveStageInfo extends SparkListener { val stageInfos = mutable.Buffer[StageInfo]() - def onStageCompleted(stage: StageCompleted) { + override def onStageCompleted(stage: StageCompleted) { stageInfos += stage.stageInfo } } diff --git a/core/src/test/scala/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/spark/storage/BlockManagerSuite.scala 
index b8c0f6fb76..b9d5f9668e 100644 --- a/core/src/test/scala/spark/storage/BlockManagerSuite.scala +++ b/core/src/test/scala/spark/storage/BlockManagerSuite.scala @@ -15,8 +15,10 @@ import org.scalatest.time.SpanSugar._ import spark.JavaSerializer import spark.KryoSerializer import spark.SizeEstimator +import spark.util.AkkaUtils import spark.util.ByteBufferInputStream + class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodTester { var store: BlockManager = null var store2: BlockManager = null @@ -31,7 +33,11 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT val serializer = new KryoSerializer before { - actorSystem = ActorSystem("test") + val (actorSystem, boundPort) = AkkaUtils.createActorSystem("test", "localhost", 0) + this.actorSystem = actorSystem + System.setProperty("spark.driver.port", boundPort.toString) + System.setProperty("spark.hostPort", "localhost:" + boundPort) + master = new BlockManagerMaster( actorSystem.actorOf(Props(new spark.storage.BlockManagerMasterActor(true)))) @@ -41,9 +47,14 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT oldHeartBeat = System.setProperty("spark.storage.disableBlockManagerHeartBeat", "true") val initialize = PrivateMethod[Unit]('initialize) SizeEstimator invokePrivate initialize() + // Set some value ... 
+ System.setProperty("spark.hostPort", spark.Utils.localHostName() + ":" + 1111) } after { + System.clearProperty("spark.driver.port") + System.clearProperty("spark.hostPort") + if (store != null) { store.stop() store = null @@ -88,9 +99,9 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT } test("BlockManagerId object caching") { - val id1 = BlockManagerId("e1", "XXX", 1) - val id2 = BlockManagerId("e1", "XXX", 1) // this should return the same object as id1 - val id3 = BlockManagerId("e1", "XXX", 2) // this should return a different object + val id1 = BlockManagerId("e1", "XXX", 1, 0) + val id2 = BlockManagerId("e1", "XXX", 1, 0) // this should return the same object as id1 + val id3 = BlockManagerId("e1", "XXX", 2, 0) // this should return a different object assert(id2 === id1, "id2 is not same as id1") assert(id2.eq(id1), "id2 is not the same object as id1") assert(id3 != id1, "id3 is same as id1") @@ -113,7 +124,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT // Putting a1, a2 and a3 in memory and telling master only about a1 and a2 store.putSingle("a1", a1, StorageLevel.MEMORY_ONLY) store.putSingle("a2", a2, StorageLevel.MEMORY_ONLY) - store.putSingle("a3", a3, StorageLevel.MEMORY_ONLY, false) + store.putSingle("a3", a3, StorageLevel.MEMORY_ONLY, tellMaster = false) // Checking whether blocks are in memory assert(store.getSingle("a1") != None, "a1 was not in store") @@ -159,7 +170,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT // Putting a1, a2 and a3 in memory and telling master only about a1 and a2 store.putSingle("a1-to-remove", a1, StorageLevel.MEMORY_ONLY) store.putSingle("a2-to-remove", a2, StorageLevel.MEMORY_ONLY) - store.putSingle("a3-to-remove", a3, StorageLevel.MEMORY_ONLY, false) + store.putSingle("a3-to-remove", a3, StorageLevel.MEMORY_ONLY, tellMaster = false) // Checking whether blocks are in memory and memory size val memStatus = 
master.getMemoryStatus.head._2 @@ -198,6 +209,39 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT } } + test("removing rdd") { + store = new BlockManager("<driver>", actorSystem, master, serializer, 2000) + val a1 = new Array[Byte](400) + val a2 = new Array[Byte](400) + val a3 = new Array[Byte](400) + // Putting a1, a2 and a3 in memory. + store.putSingle("rdd_0_0", a1, StorageLevel.MEMORY_ONLY) + store.putSingle("rdd_0_1", a2, StorageLevel.MEMORY_ONLY) + store.putSingle("nonrddblock", a3, StorageLevel.MEMORY_ONLY) + master.removeRdd(0, blocking = false) + + eventually(timeout(1000 milliseconds), interval(10 milliseconds)) { + store.getSingle("rdd_0_0") should be (None) + master.getLocations("rdd_0_0") should have size 0 + } + eventually(timeout(1000 milliseconds), interval(10 milliseconds)) { + store.getSingle("rdd_0_1") should be (None) + master.getLocations("rdd_0_1") should have size 0 + } + eventually(timeout(1000 milliseconds), interval(10 milliseconds)) { + store.getSingle("nonrddblock") should not be (None) + master.getLocations("nonrddblock") should have size (1) + } + + store.putSingle("rdd_0_0", a1, StorageLevel.MEMORY_ONLY) + store.putSingle("rdd_0_1", a2, StorageLevel.MEMORY_ONLY) + master.removeRdd(0, blocking = true) + store.getSingle("rdd_0_0") should be (None) + master.getLocations("rdd_0_0") should have size 0 + store.getSingle("rdd_0_1") should be (None) + master.getLocations("rdd_0_1") should have size 0 + } + test("reregistration on heart beat") { val heartBeat = PrivateMethod[Unit]('heartBeat) store = new BlockManager("<driver>", actorSystem, master, serializer, 2000) @@ -226,7 +270,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT master.removeExecutor(store.blockManagerId.executorId) assert(master.getLocations("a1").size == 0, "a1 was not removed from master") - store.putSingle("a2", a1, StorageLevel.MEMORY_ONLY) + store.putSingle("a2", a2, StorageLevel.MEMORY_ONLY) 
store.waitForAsyncReregister() assert(master.getLocations("a1").size > 0, "a1 was not reregistered with master") @@ -244,7 +288,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT master.removeExecutor(store.blockManagerId.executorId) val t1 = new Thread { override def run() { - store.put("a2", a2.iterator, StorageLevel.MEMORY_ONLY, true) + store.put("a2", a2.iterator, StorageLevel.MEMORY_ONLY, tellMaster = true) } } val t2 = new Thread { @@ -454,9 +498,9 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT val list1 = List(new Array[Byte](200), new Array[Byte](200)) val list2 = List(new Array[Byte](200), new Array[Byte](200)) val list3 = List(new Array[Byte](200), new Array[Byte](200)) - store.put("list1", list1.iterator, StorageLevel.MEMORY_ONLY, true) - store.put("list2", list2.iterator, StorageLevel.MEMORY_ONLY, true) - store.put("list3", list3.iterator, StorageLevel.MEMORY_ONLY, true) + store.put("list1", list1.iterator, StorageLevel.MEMORY_ONLY, tellMaster = true) + store.put("list2", list2.iterator, StorageLevel.MEMORY_ONLY, tellMaster = true) + store.put("list3", list3.iterator, StorageLevel.MEMORY_ONLY, tellMaster = true) assert(store.get("list2") != None, "list2 was not in store") assert(store.get("list2").get.size == 2) assert(store.get("list3") != None, "list3 was not in store") @@ -465,7 +509,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT assert(store.get("list2") != None, "list2 was not in store") assert(store.get("list2").get.size == 2) // At this point list2 was gotten last, so LRU will getSingle rid of list3 - store.put("list1", list1.iterator, StorageLevel.MEMORY_ONLY, true) + store.put("list1", list1.iterator, StorageLevel.MEMORY_ONLY, tellMaster = true) assert(store.get("list1") != None, "list1 was not in store") assert(store.get("list1").get.size == 2) assert(store.get("list2") != None, "list2 was not in store") @@ -480,9 +524,9 @@ class 
BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT val list3 = List(new Array[Byte](200), new Array[Byte](200)) val list4 = List(new Array[Byte](200), new Array[Byte](200)) // First store list1 and list2, both in memory, and list3, on disk only - store.put("list1", list1.iterator, StorageLevel.MEMORY_ONLY_SER, true) - store.put("list2", list2.iterator, StorageLevel.MEMORY_ONLY_SER, true) - store.put("list3", list3.iterator, StorageLevel.DISK_ONLY, true) + store.put("list1", list1.iterator, StorageLevel.MEMORY_ONLY_SER, tellMaster = true) + store.put("list2", list2.iterator, StorageLevel.MEMORY_ONLY_SER, tellMaster = true) + store.put("list3", list3.iterator, StorageLevel.DISK_ONLY, tellMaster = true) // At this point LRU should not kick in because list3 is only on disk assert(store.get("list1") != None, "list2 was not in store") assert(store.get("list1").get.size === 2) @@ -497,7 +541,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT assert(store.get("list3") != None, "list1 was not in store") assert(store.get("list3").get.size === 2) // Now let's add in list4, which uses both disk and memory; list1 should drop out - store.put("list4", list4.iterator, StorageLevel.MEMORY_AND_DISK_SER, true) + store.put("list4", list4.iterator, StorageLevel.MEMORY_AND_DISK_SER, tellMaster = true) assert(store.get("list1") === None, "list1 was in store") assert(store.get("list2") != None, "list3 was not in store") assert(store.get("list2").get.size === 2) |