diff options
28 files changed, 985 insertions, 334 deletions
diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index cc44a4c7dd..a12f8860b9 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -248,7 +248,6 @@ class SparkContext( taskScheduler.start() @volatile private[spark] var dagScheduler = new DAGScheduler(taskScheduler) - dagScheduler.start() ui.start() @@ -282,6 +281,12 @@ class SparkContext( override protected def childValue(parent: Properties): Properties = new Properties(parent) } + private[spark] def getLocalProperties(): Properties = localProperties.get() + + private[spark] def setLocalProperties(props: Properties) { + localProperties.set(props) + } + def initLocalProperties() { localProperties.set(new Properties()) } @@ -303,7 +308,7 @@ class SparkContext( /** Set a human readable description of the current job. */ @deprecated("use setJobGroup", "0.8.1") def setJobDescription(value: String) { - setJobGroup("", value) + setLocalProperty(SparkContext.SPARK_JOB_DESCRIPTION, value) } /** diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala index c29a30184a..fc1537f796 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala @@ -23,7 +23,7 @@ import org.apache.hadoop.conf.Configuration import org.apache.hadoop.mapred.JobConf import org.apache.hadoop.security.UserGroupInformation -import org.apache.spark.SparkException +import org.apache.spark.{SparkContext, SparkException} /** * Contains util methods to interact with Hadoop from Spark. 
@@ -34,10 +34,21 @@ class SparkHadoopUtil { UserGroupInformation.setConfiguration(conf) def runAsUser(user: String)(func: () => Unit) { - val ugi = UserGroupInformation.createRemoteUser(user) - ugi.doAs(new PrivilegedExceptionAction[Unit] { - def run: Unit = func() - }) + // if we are already running as the user intended there is no reason to do the doAs. It + // will actually break secure HDFS access as it doesn't fill in the credentials. Also if + // the user is UNKNOWN then we shouldn't be creating a remote unknown user + // (this is actually the path spark on yarn takes) since SPARK_USER is initialized only + // in SparkContext. + val currentUser = Option(System.getProperty("user.name")). + getOrElse(SparkContext.SPARK_UNKNOWN_USER) + if (user != SparkContext.SPARK_UNKNOWN_USER && currentUser != user) { + val ugi = UserGroupInformation.createRemoteUser(user) + ugi.doAs(new PrivilegedExceptionAction[Unit] { + def run: Unit = func() + }) + } else { + func() + } } /** diff --git a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala index 32901a508f..47e958b5e6 100644 --- a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala @@ -132,6 +132,8 @@ class HadoopRDD[K, V]( override def getPartitions: Array[Partition] = { val jobConf = getJobConf() + // add the credentials here as this can be called before SparkContext initialized + SparkHadoopUtil.get.addCredentials(jobConf) val inputFormat = getInputFormat(jobConf) if (inputFormat.isInstanceOf[Configurable]) { inputFormat.asInstanceOf[Configurable].setConf(jobConf) diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index ab7b3a2e24..7b4fc6b9be 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ 
-19,9 +19,11 @@ package org.apache.spark.scheduler import java.io.NotSerializableException import java.util.Properties -import java.util.concurrent.{LinkedBlockingQueue, TimeUnit} import java.util.concurrent.atomic.AtomicInteger +import scala.concurrent.duration._ +import scala.concurrent.ExecutionContext.Implicits.global +import akka.actor._ import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, Map} import scala.reflect.ClassTag @@ -66,12 +68,12 @@ class DAGScheduler( // Called by TaskScheduler to report task's starting. def taskStarted(task: Task[_], taskInfo: TaskInfo) { - eventQueue.put(BeginEvent(task, taskInfo)) + eventProcessActor ! BeginEvent(task, taskInfo) } // Called to report that a task has completed and results are being fetched remotely. def taskGettingResult(task: Task[_], taskInfo: TaskInfo) { - eventQueue.put(GettingResultEvent(task, taskInfo)) + eventProcessActor ! GettingResultEvent(task, taskInfo) } // Called by TaskScheduler to report task completions or failures. @@ -82,35 +84,58 @@ class DAGScheduler( accumUpdates: Map[Long, Any], taskInfo: TaskInfo, taskMetrics: TaskMetrics) { - eventQueue.put(CompletionEvent(task, reason, result, accumUpdates, taskInfo, taskMetrics)) + eventProcessActor ! CompletionEvent(task, reason, result, accumUpdates, taskInfo, taskMetrics) } // Called by TaskScheduler when an executor fails. def executorLost(execId: String) { - eventQueue.put(ExecutorLost(execId)) + eventProcessActor ! ExecutorLost(execId) } // Called by TaskScheduler when a host is added def executorGained(execId: String, host: String) { - eventQueue.put(ExecutorGained(execId, host)) + eventProcessActor ! ExecutorGained(execId, host) } // Called by TaskScheduler to cancel an entire TaskSet due to either repeated failures or // cancellation of the job itself. def taskSetFailed(taskSet: TaskSet, reason: String) { - eventQueue.put(TaskSetFailed(taskSet, reason)) + eventProcessActor ! 
TaskSetFailed(taskSet, reason) } // The time, in millis, to wait for fetch failure events to stop coming in after one is detected; // this is a simplistic way to avoid resubmitting tasks in the non-fetchable map stage one by one // as more failure events come in - val RESUBMIT_TIMEOUT = 50L + val RESUBMIT_TIMEOUT = 50.milliseconds // The time, in millis, to wake up between polls of the completion queue in order to potentially // resubmit failed stages val POLL_TIMEOUT = 10L - private val eventQueue = new LinkedBlockingQueue[DAGSchedulerEvent] + private val eventProcessActor: ActorRef = env.actorSystem.actorOf(Props(new Actor { + override def preStart() { + context.system.scheduler.schedule(RESUBMIT_TIMEOUT, RESUBMIT_TIMEOUT) { + if (failed.size > 0) { + resubmitFailedStages() + } + } + } + + /** + * The main event loop of the DAG scheduler, which waits for new-job / task-finished / failure + * events and responds by launching tasks. Events arrive as messages sent to this actor + * via eventProcessActor. 
+ */ + def receive = { + case event: DAGSchedulerEvent => + logDebug("Got event of type " + event.getClass.getName) + + if (!processEvent(event)) + submitWaitingStages() + else + context.stop(self) + } + })) private[scheduler] val nextJobId = new AtomicInteger(0) @@ -151,16 +176,6 @@ class DAGScheduler( val metadataCleaner = new MetadataCleaner(MetadataCleanerType.DAG_SCHEDULER, this.cleanup) - // Start a thread to run the DAGScheduler event loop - def start() { - new Thread("DAGScheduler") { - setDaemon(true) - override def run() { - DAGScheduler.this.run() - } - }.start() - } - def addSparkListener(listener: SparkListener) { listenerBus.addListener(listener) } @@ -302,8 +317,7 @@ class DAGScheduler( assert(partitions.size > 0) val func2 = func.asInstanceOf[(TaskContext, Iterator[_]) => _] val waiter = new JobWaiter(this, jobId, partitions.size, resultHandler) - eventQueue.put(JobSubmitted(jobId, rdd, func2, partitions.toArray, allowLocal, callSite, - waiter, properties)) + eventProcessActor ! JobSubmitted(jobId, rdd, func2, partitions.toArray, allowLocal, callSite, waiter, properties) waiter } @@ -338,8 +352,7 @@ class DAGScheduler( val func2 = func.asInstanceOf[(TaskContext, Iterator[_]) => _] val partitions = (0 until rdd.partitions.size).toArray val jobId = nextJobId.getAndIncrement() - eventQueue.put(JobSubmitted(jobId, rdd, func2, partitions, allowLocal = false, callSite, - listener, properties)) + eventProcessActor ! JobSubmitted(jobId, rdd, func2, partitions, allowLocal = false, callSite, listener, properties) listener.awaitResult() // Will throw an exception if the job fails } @@ -348,19 +361,19 @@ class DAGScheduler( */ def cancelJob(jobId: Int) { logInfo("Asked to cancel job " + jobId) - eventQueue.put(JobCancelled(jobId)) + eventProcessActor ! JobCancelled(jobId) } def cancelJobGroup(groupId: String) { logInfo("Asked to cancel job group " + groupId) - eventQueue.put(JobGroupCancelled(groupId)) + eventProcessActor ! 
JobGroupCancelled(groupId) } /** * Cancel all jobs that are running or waiting in the queue. */ def cancelAllJobs() { - eventQueue.put(AllJobsCancelled) + eventProcessActor ! AllJobsCancelled } /** @@ -475,42 +488,6 @@ class DAGScheduler( } } - - /** - * The main event loop of the DAG scheduler, which waits for new-job / task-finished / failure - * events and responds by launching tasks. This runs in a dedicated thread and receives events - * via the eventQueue. - */ - private def run() { - SparkEnv.set(env) - - while (true) { - val event = eventQueue.poll(POLL_TIMEOUT, TimeUnit.MILLISECONDS) - if (event != null) { - logDebug("Got event of type " + event.getClass.getName) - } - this.synchronized { // needed in case other threads makes calls into methods of this class - if (event != null) { - if (processEvent(event)) { - return - } - } - - val time = System.currentTimeMillis() // TODO: use a pluggable clock for testability - // Periodically resubmit failed stages if some map output fetches have failed and we have - // waited at least RESUBMIT_TIMEOUT. We wait for this short time because when a node fails, - // tasks on many other nodes are bound to get a fetch failure, and they won't all get it at - // the same time, so we want to make sure we've identified all the reduce tasks that depend - // on the failed node. - if (failed.size > 0 && time > lastFetchFailureTime + RESUBMIT_TIMEOUT) { - resubmitFailedStages() - } else { - submitWaitingStages() - } - } - } - } - /** * Run a job on an RDD locally, assuming it has only a single partition and no dependencies. * We run the operation in a separate thread just in case it takes a bunch of time, so that we @@ -879,7 +856,7 @@ class DAGScheduler( // If the RDD has narrow dependencies, pick the first partition of the first narrow dep // that has any placement preferences. Ideally we would choose based on transfer sizes, // but this will do for now. 
- rdd.dependencies.foreach(_ match { + rdd.dependencies.foreach { case n: NarrowDependency[_] => for (inPart <- n.getParents(partition)) { val locs = getPreferredLocs(n.rdd, inPart) @@ -887,7 +864,7 @@ class DAGScheduler( return locs } case _ => - }) + } Nil } @@ -910,7 +887,7 @@ class DAGScheduler( } def stop() { - eventQueue.put(StopDAGScheduler) + eventProcessActor ! StopDAGScheduler metadataCleaner.cancel() taskSched.stop() } diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterScheduler.scala index 85033958ef..2d8a0a62c9 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterScheduler.scala @@ -25,6 +25,9 @@ import scala.collection.mutable.ArrayBuffer import scala.collection.mutable.HashMap import scala.collection.mutable.HashSet +import scala.concurrent.duration._ +import scala.concurrent.ExecutionContext.Implicits.global + import org.apache.spark._ import org.apache.spark.TaskState.TaskState import org.apache.spark.scheduler._ @@ -119,21 +122,12 @@ private[spark] class ClusterScheduler(val sc: SparkContext) backend.start() if (System.getProperty("spark.speculation", "false").toBoolean) { - new Thread("ClusterScheduler speculation check") { - setDaemon(true) - - override def run() { - logInfo("Starting speculative execution thread") - while (true) { - try { - Thread.sleep(SPECULATION_INTERVAL) - } catch { - case e: InterruptedException => {} - } - checkSpeculatableTasks() - } - } - }.start() + logInfo("Starting speculative execution thread") + + sc.env.actorSystem.scheduler.schedule(SPECULATION_INTERVAL milliseconds, + SPECULATION_INTERVAL milliseconds) { + checkSpeculatableTasks() + } } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/SimrSchedulerBackend.scala 
b/core/src/main/scala/org/apache/spark/scheduler/cluster/SimrSchedulerBackend.scala index d78bdbaa7a..6b91935400 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/SimrSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/SimrSchedulerBackend.scala @@ -31,6 +31,10 @@ private[spark] class SimrSchedulerBackend( val tmpPath = new Path(driverFilePath + "_tmp") val filePath = new Path(driverFilePath) + val uiFilePath = driverFilePath + "_ui" + val tmpUiPath = new Path(uiFilePath + "_tmp") + val uiPath = new Path(uiFilePath) + val maxCores = System.getProperty("spark.simr.executor.cores", "1").toInt override def start() { @@ -45,6 +49,8 @@ private[spark] class SimrSchedulerBackend( logInfo("Writing to HDFS file: " + driverFilePath) logInfo("Writing Akka address: " + driverUrl) + logInfo("Writing to HDFS file: " + uiFilePath) + logInfo("Writing Spark UI Address: " + sc.ui.appUIAddress) // Create temporary file to prevent race condition where executors get empty driverUrl file val temp = fs.create(tmpPath, true) @@ -54,6 +60,12 @@ private[spark] class SimrSchedulerBackend( // "Atomic" rename fs.rename(tmpPath, filePath) + + // Write Spark UI Address to file + val uiTemp = fs.create(tmpUiPath, true) + uiTemp.writeUTF(sc.ui.appUIAddress) + uiTemp.close() + fs.rename(tmpUiPath, uiPath) } override def stop() { diff --git a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala index 55b25f145a..e748c2275d 100644 --- a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala +++ b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala @@ -27,13 +27,17 @@ import com.twitter.chill.{EmptyScalaKryoInstantiator, AllScalaRegistrar} import org.apache.spark.{SerializableWritable, Logging} import org.apache.spark.broadcast.HttpBroadcast -import org.apache.spark.storage.{GetBlock,GotBlock, PutBlock, StorageLevel, TestBlockId} 
+import org.apache.spark.scheduler.MapStatus +import org.apache.spark.storage._ /** - * A Spark serializer that uses the [[http://code.google.com/p/kryo/wiki/V1Documentation Kryo 1.x library]]. + * A Spark serializer that uses the [[https://code.google.com/p/kryo/ Kryo serialization library]]. */ class KryoSerializer extends org.apache.spark.serializer.Serializer with Logging { - private val bufferSize = System.getProperty("spark.kryoserializer.buffer.mb", "2").toInt * 1024 * 1024 + + private val bufferSize = { + System.getProperty("spark.kryoserializer.buffer.mb", "2").toInt * 1024 * 1024 + } def newKryoOutput() = new KryoOutput(bufferSize) @@ -42,21 +46,11 @@ class KryoSerializer extends org.apache.spark.serializer.Serializer with Logging val kryo = instantiator.newKryo() val classLoader = Thread.currentThread.getContextClassLoader - val blockId = TestBlockId("1") - // Register some commonly used classes - val toRegister: Seq[AnyRef] = Seq( - ByteBuffer.allocate(1), - StorageLevel.MEMORY_ONLY, - PutBlock(blockId, ByteBuffer.allocate(1), StorageLevel.MEMORY_ONLY), - GotBlock(blockId, ByteBuffer.allocate(1)), - GetBlock(blockId), - 1 to 10, - 1 until 10, - 1L to 10L, - 1L until 10L - ) - - for (obj <- toRegister) kryo.register(obj.getClass) + // Allow disabling Kryo reference tracking if user knows their object graphs don't have loops. + // Do this before we invoke the user registrator so the user registrator can override this. 
+ kryo.setReferences(System.getProperty("spark.kryo.referenceTracking", "true").toBoolean) + + for (cls <- KryoSerializer.toRegister) kryo.register(cls) // Allow sending SerializableWritable kryo.register(classOf[SerializableWritable[_]], new KryoJavaSerializer()) @@ -78,10 +72,6 @@ class KryoSerializer extends org.apache.spark.serializer.Serializer with Logging new AllScalaRegistrar().apply(kryo) kryo.setClassLoader(classLoader) - - // Allow disabling Kryo reference tracking if user knows their object graphs don't have loops - kryo.setReferences(System.getProperty("spark.kryo.referenceTracking", "true").toBoolean) - kryo } @@ -165,3 +155,21 @@ private[spark] class KryoSerializerInstance(ks: KryoSerializer) extends Serializ trait KryoRegistrator { def registerClasses(kryo: Kryo) } + +private[serializer] object KryoSerializer { + // Commonly used classes. + private val toRegister: Seq[Class[_]] = Seq( + ByteBuffer.allocate(1).getClass, + classOf[StorageLevel], + classOf[PutBlock], + classOf[GotBlock], + classOf[GetBlock], + classOf[MapStatus], + classOf[BlockManagerId], + classOf[Array[Byte]], + (1 to 10).getClass, + (1 until 10).getClass, + (1L to 10L).getClass, + (1L until 10L).getClass + ) +} diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index 252329c4e1..7e721a49a5 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -892,9 +892,9 @@ private[spark] object BlockManager extends Logging { blockManagerMaster: BlockManagerMaster = null) : Map[BlockId, Seq[BlockManagerId]] = { - // env == null and blockManagerMaster != null is used in tests + // blockManagerMaster != null is used in tests assert (env != null || blockManagerMaster != null) - val blockLocations: Seq[Seq[BlockManagerId]] = if (env != null) { + val blockLocations: Seq[Seq[BlockManagerId]] = if (blockManagerMaster 
== null) { env.blockManager.getLocationBlockIds(blockIds) } else { blockManagerMaster.getLocations(blockIds) diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala index 35b5d5fd59..c1c7aa70e6 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala @@ -152,6 +152,22 @@ private[spark] class StagePage(parent: JobProgressUI) { else metrics.map(m => parent.formatDuration(m.executorRunTime)).getOrElse("") val gcTime = metrics.map(m => m.jvmGCTime).getOrElse(0L) + var shuffleReadSortable: String = "" + var shuffleReadReadable: String = "" + if (shuffleRead) { // sort key is the raw byte count ("" when absent), not Option.toString + shuffleReadSortable = metrics.flatMap{m => m.shuffleReadMetrics}.map{s => s.remoteBytesRead.toString}.getOrElse("") + shuffleReadReadable = metrics.flatMap{m => m.shuffleReadMetrics}.map{s => + Utils.bytesToString(s.remoteBytesRead)}.getOrElse("") + } + + var shuffleWriteSortable: String = "" + var shuffleWriteReadable: String = "" + if (shuffleWrite) { // sort key is the raw byte count ("" when absent), not Option.toString + shuffleWriteSortable = metrics.flatMap{m => m.shuffleWriteMetrics}.map{s => s.shuffleBytesWritten.toString}.getOrElse("") + shuffleWriteReadable = metrics.flatMap{m => m.shuffleWriteMetrics}.map{s => + Utils.bytesToString(s.shuffleBytesWritten)}.getOrElse("") + } + <tr> <td>{info.index}</td> <td>{info.taskId}</td> @@ -166,14 +182,17 @@ private[spark] class StagePage(parent: JobProgressUI) { {if (gcTime > 0) parent.formatDuration(gcTime) else ""} </td> {if (shuffleRead) { - <td>{metrics.flatMap{m => m.shuffleReadMetrics}.map{s => - Utils.bytesToString(s.remoteBytesRead)}.getOrElse("")}</td> + <td sorttable_customkey={shuffleReadSortable}> + {shuffleReadReadable} + </td> }} {if (shuffleWrite) { - <td>{metrics.flatMap{m => m.shuffleWriteMetrics}.map{s => - 
parent.formatDuration(s.shuffleWriteTime / (1000 * 1000))}.getOrElse("")}</td> - <td>{metrics.flatMap{m => m.shuffleWriteMetrics}.map{s => - 
Utils.bytesToString(s.shuffleBytesWritten)}.getOrElse("")}</td> + <td>{metrics.flatMap{m => m.shuffleWriteMetrics}.map{s => + parent.formatDuration(s.shuffleWriteTime / (1000 * 1000))}.getOrElse("")} + </td> + <td sorttable_customkey={shuffleWriteSortable}> + {shuffleWriteReadable} + </td> }} <td>{exception.map(e => <span> diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala index d7d0441c38..9ad6de3c6d 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala @@ -79,11 +79,14 @@ private[spark] class StageTable(val stages: Seq[StageInfo], val parent: JobProgr case None => "Unknown" } - val shuffleRead = listener.stageIdToShuffleRead.getOrElse(s.stageId, 0L) match { + val shuffleReadSortable = listener.stageIdToShuffleRead.getOrElse(s.stageId, 0L) // raw bytes, used as the table sort key + val shuffleRead = shuffleReadSortable match { case 0 => "" case b => Utils.bytesToString(b) } - val shuffleWrite = listener.stageIdToShuffleWrite.getOrElse(s.stageId, 0L) match { + + val shuffleWriteSortable = listener.stageIdToShuffleWrite.getOrElse(s.stageId, 0L) // raw bytes, used as the table sort key + val shuffleWrite = shuffleWriteSortable match { case 0 => "" case b => Utils.bytesToString(b) } @@ -119,8 +122,8 @@ private[spark] class StageTable(val stages: Seq[StageInfo], val parent: JobProgr <td class="progress-cell"> {makeProgressBar(startedTasks, completedTasks, failedTasks, totalTasks)} </td> - <td>{shuffleRead}</td> - <td>{shuffleWrite}</td> + <td sorttable_customkey={shuffleReadSortable.toString}>{shuffleRead}</td> + <td sorttable_customkey={shuffleWriteSortable.toString}>{shuffleWrite}</td> </tr> } } diff --git a/core/src/test/scala/org/apache/spark/deploy/worker/ExecutorRunnerTest.scala b/core/src/test/scala/org/apache/spark/deploy/worker/ExecutorRunnerTest.scala index d433806987..8f0954122b 100644 --- 
a/core/src/test/scala/org/apache/spark/deploy/worker/ExecutorRunnerTest.scala +++ b/core/src/test/scala/org/apache/spark/deploy/worker/ExecutorRunnerTest.scala @@ -5,10 +5,9 @@ import org.scalatest.FunSuite import org.apache.spark.deploy.{ExecutorState, Command, ApplicationDescription} class ExecutorRunnerTest extends FunSuite { - test("command includes appId") { def f(s:String) = new File(s) - val sparkHome = sys.props("user.dir") + val sparkHome = sys.env("SPARK_HOME") val appDesc = new ApplicationDescription("app name", 8, 500, Command("foo", Seq(),Map()), sparkHome, "appUiUrl") val appId = "12345-worker321-9876" diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala index 00f2fdd657..a4d41ebbff 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala @@ -100,7 +100,7 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with LocalSparkCont cacheLocations.clear() results.clear() mapOutputTracker = new MapOutputTrackerMaster() - scheduler = new DAGScheduler(taskScheduler, mapOutputTracker, blockManagerMaster, null) { + scheduler = new DAGScheduler(taskScheduler, mapOutputTracker, blockManagerMaster, sc.env) { override def runLocally(job: ActiveJob) { // don't bother with the thread while unit testing runLocallyWithinThread(job) diff --git a/docs/running-on-yarn.md b/docs/running-on-yarn.md index 2898af0bed..6fd1d0d150 100644 --- a/docs/running-on-yarn.md +++ b/docs/running-on-yarn.md @@ -21,6 +21,7 @@ The assembled JAR will be something like this: # Preparations - Building a YARN-enabled assembly (see above). +- The assembled jar can be installed into HDFS or used locally. - Your application code must be packaged into a separate JAR file. If you want to test out the YARN deployment mode, you can use the current Spark examples. 
A `spark-examples_{{site.SCALA_VERSION}}-{{site.SPARK_VERSION}}` file can be generated by running `sbt/sbt assembly`. NOTE: since the documentation you're reading is for Spark version {{site.SPARK_VERSION}}, we are assuming here that you have downloaded Spark {{site.SPARK_VERSION}} or checked it out of source control. If you are using a different version of Spark, the version numbers in the jar generated by the sbt package command will obviously be different. diff --git a/ec2/spark_ec2.py b/ec2/spark_ec2.py index 79848380c0..1189232428 100755 --- a/ec2/spark_ec2.py +++ b/ec2/spark_ec2.py @@ -72,12 +72,12 @@ def parse_args(): parser.add_option("-a", "--ami", help="Amazon Machine Image ID to use") parser.add_option("-v", "--spark-version", default="0.8.0", help="Version of Spark to use: 'X.Y.Z' or a specific git hash") - parser.add_option("--spark-git-repo", + parser.add_option("--spark-git-repo", default="https://github.com/apache/incubator-spark", help="Github repo from which to checkout supplied commit hash") parser.add_option("--hadoop-major-version", default="1", help="Major version of Hadoop (default: 1)") - parser.add_option("-D", metavar="[ADDRESS:]PORT", dest="proxy_port", + parser.add_option("-D", metavar="[ADDRESS:]PORT", dest="proxy_port", help="Use SSH dynamic port forwarding to create a SOCKS proxy at " + "the given local address (for use with login)") parser.add_option("--resume", action="store_true", default=False, @@ -101,6 +101,8 @@ def parse_args(): help="The SSH user you want to connect as (default: root)") parser.add_option("--delete-groups", action="store_true", default=False, help="When destroying a cluster, delete the security groups that were created") + parser.add_option("--use-existing-master", action="store_true", default=False, + help="Launch fresh slaves, but use an existing stopped master if possible") (opts, args) = parser.parse_args() if len(args) != 2: @@ -191,7 +193,7 @@ def get_spark_ami(opts): instance_type = "pvm" print >> 
stderr,\ "Don't recognize %s, assuming type is pvm" % opts.instance_type - + ami_path = "%s/%s/%s" % (AMI_PREFIX, opts.region, instance_type) try: ami = urllib2.urlopen(ami_path).read().strip() @@ -215,6 +217,7 @@ def launch_cluster(conn, opts, cluster_name): master_group.authorize(src_group=slave_group) master_group.authorize('tcp', 22, 22, '0.0.0.0/0') master_group.authorize('tcp', 8080, 8081, '0.0.0.0/0') + master_group.authorize('tcp', 19999, 19999, '0.0.0.0/0') master_group.authorize('tcp', 50030, 50030, '0.0.0.0/0') master_group.authorize('tcp', 50070, 50070, '0.0.0.0/0') master_group.authorize('tcp', 60070, 60070, '0.0.0.0/0') @@ -232,9 +235,9 @@ def launch_cluster(conn, opts, cluster_name): slave_group.authorize('tcp', 60075, 60075, '0.0.0.0/0') # Check if instances are already running in our groups - active_nodes = get_existing_cluster(conn, opts, cluster_name, - die_on_error=False) - if any(active_nodes): + existing_masters, existing_slaves = get_existing_cluster(conn, opts, cluster_name, + die_on_error=False) + if existing_slaves or (existing_masters and not opts.use_existing_master): print >> stderr, ("ERROR: There are already instances running in " + "group %s or %s" % (master_group.name, slave_group.name)) sys.exit(1) @@ -335,21 +338,28 @@ def launch_cluster(conn, opts, cluster_name): zone, slave_res.id) i += 1 - # Launch masters - master_type = opts.master_instance_type - if master_type == "": - master_type = opts.instance_type - if opts.zone == 'all': - opts.zone = random.choice(conn.get_all_zones()).name - master_res = image.run(key_name = opts.key_pair, - security_groups = [master_group], - instance_type = master_type, - placement = opts.zone, - min_count = 1, - max_count = 1, - block_device_map = block_map) - master_nodes = master_res.instances - print "Launched master in %s, regid = %s" % (zone, master_res.id) + # Launch or resume masters + if existing_masters: + print "Starting master..." 
+ for inst in existing_masters: + if inst.state not in ["shutting-down", "terminated"]: + inst.start() + master_nodes = existing_masters + else: + master_type = opts.master_instance_type + if master_type == "": + master_type = opts.instance_type + if opts.zone == 'all': + opts.zone = random.choice(conn.get_all_zones()).name + master_res = image.run(key_name = opts.key_pair, + security_groups = [master_group], + instance_type = master_type, + placement = opts.zone, + min_count = 1, + max_count = 1, + block_device_map = block_map) + master_nodes = master_res.instances + print "Launched master in %s, regid = %s" % (zone, master_res.id) # Return all the instances return (master_nodes, slave_nodes) @@ -403,8 +413,8 @@ def setup_cluster(conn, master_nodes, slave_nodes, opts, deploy_ssh_key): print slave.public_dns_name ssh_write(slave.public_dns_name, opts, ['tar', 'x'], dot_ssh_tar) - modules = ['spark', 'shark', 'ephemeral-hdfs', 'persistent-hdfs', - 'mapreduce', 'spark-standalone'] + modules = ['spark', 'shark', 'ephemeral-hdfs', 'persistent-hdfs', + 'mapreduce', 'spark-standalone', 'tachyon'] if opts.hadoop_major_version == "1": modules = filter(lambda x: x != "mapreduce", modules) @@ -668,12 +678,12 @@ def real_main(): print "Terminating slaves..." for inst in slave_nodes: inst.terminate() - + # Delete security groups as well if opts.delete_groups: print "Deleting security groups (this will take some time)..." group_names = [cluster_name + "-master", cluster_name + "-slaves"] - + attempt = 1; while attempt <= 3: print "Attempt %d" % attempt @@ -731,6 +741,7 @@ def real_main(): cluster_name + "?\nDATA ON EPHEMERAL DISKS WILL BE LOST, " + "BUT THE CLUSTER WILL KEEP USING SPACE ON\n" + "AMAZON EBS IF IT IS EBS-BACKED!!\n" + + "All data on spot-instance slaves will be lost.\n" + "Stop cluster " + cluster_name + " (y/N): ") if response == "y": (master_nodes, slave_nodes) = get_existing_cluster( @@ -742,7 +753,10 @@ def real_main(): print "Stopping slaves..." 
for inst in slave_nodes: if inst.state not in ["shutting-down", "terminated"]: - inst.stop() + if inst.spot_instance_request_id: + inst.terminate() + else: + inst.stop() elif action == "start": (master_nodes, slave_nodes) = get_existing_cluster(conn, opts, cluster_name) @@ -391,6 +391,12 @@ <version>3.1</version> <scope>test</scope> </dependency> + <dependency> + <groupId>org.mockito</groupId> + <artifactId>mockito-all</artifactId> + <version>1.8.5</version> + <scope>test</scope> + </dependency> <dependency> <groupId>org.scalacheck</groupId> <artifactId>scalacheck_${scala-short.version}</artifactId> diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index b71e1b3a56..9a3cbbe7d2 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -172,6 +172,7 @@ object SparkBuild extends Build { "org.scalacheck" %% "scalacheck" % "1.10.0" % "test", "com.novocode" % "junit-interface" % "0.9" % "test", "org.easymock" % "easymock" % "3.1" % "test", + "org.mockito" % "mockito-all" % "1.8.5" % "test", "commons-io" % "commons-io" % "2.4" % "test" ), @@ -268,7 +269,7 @@ object SparkBuild extends Build { def toolsSettings = sharedSettings ++ Seq( name := "spark-tools" - ) + ) ++ assemblySettings ++ extraAssemblySettings def bagelSettings = sharedSettings ++ Seq( name := "spark-bagel" @@ -333,7 +334,7 @@ object SparkBuild extends Build { case m if m.toLowerCase.endsWith("manifest.mf") => MergeStrategy.discard case m if m.toLowerCase.matches("meta-inf.*\\.sf$") => MergeStrategy.discard case "log4j.properties" => MergeStrategy.discard - case "META-INF/services/org.apache.hadoop.fs.FileSystem" => MergeStrategy.concat + case m if m.toLowerCase.startsWith("meta-inf/services/") => MergeStrategy.filterDistinctLines case "reference.conf" => MergeStrategy.concat case _ => MergeStrategy.first } diff --git a/project/plugins.sbt b/project/plugins.sbt index cfcd85082a..4ba0e4280a 100644 --- a/project/plugins.sbt +++ b/project/plugins.sbt @@ -4,7 +4,7 @@ resolvers += 
"Typesafe Repository" at "http://repo.typesafe.com/typesafe/release resolvers += "Spray Repository" at "http://repo.spray.cc/" -addSbtPlugin("com.eed3si9n" % "sbt-assembly" % "0.9.1") +addSbtPlugin("com.eed3si9n" % "sbt-assembly" % "0.9.2") addSbtPlugin("com.typesafe.sbteclipse" % "sbteclipse-plugin" % "2.2.0") diff --git a/repl/src/test/scala/org/apache/spark/repl/ReplSuite.scala b/repl/src/test/scala/org/apache/spark/repl/ReplSuite.scala index fccb6e652c..418c31e24b 100644 --- a/repl/src/test/scala/org/apache/spark/repl/ReplSuite.scala +++ b/repl/src/test/scala/org/apache/spark/repl/ReplSuite.scala @@ -5,10 +5,13 @@ import java.net.URLClassLoader import scala.collection.mutable.ArrayBuffer -import org.scalatest.FunSuite import com.google.common.io.Files +import org.scalatest.FunSuite +import org.apache.spark.SparkContext + class ReplSuite extends FunSuite { + def runInterpreter(master: String, input: String): String = { val in = new BufferedReader(new StringReader(input + "\n")) val out = new StringWriter() @@ -46,7 +49,36 @@ class ReplSuite extends FunSuite { "Interpreter output contained '" + message + "':\n" + output) } - test("simple foreach with accumulator") { + test("propagation of local properties") { + // A mock ILoop that doesn't install the SIGINT handler. + class ILoop(out: PrintWriter) extends SparkILoop(None, out, None) { + settings = new scala.tools.nsc.Settings + settings.usejavacp.value = true + org.apache.spark.repl.Main.interp = this + override def createInterpreter() { + intp = new SparkILoopInterpreter + intp.setContextClassLoader() + } + } + + val out = new StringWriter() + val interp = new ILoop(new PrintWriter(out)) + interp.sparkContext = new SparkContext("local", "repl-test") + interp.createInterpreter() + interp.intp.initialize() + interp.sparkContext.setLocalProperty("someKey", "someValue") + + // Make sure the value we set in the caller to interpret is propagated in the thread that + // interprets the command. 
+ interp.interpret("org.apache.spark.repl.Main.interp.sparkContext.getLocalProperty(\"someKey\")") + assert(out.toString.contains("someValue")) + + interp.sparkContext.stop() + System.clearProperty("spark.driver.port") + System.clearProperty("spark.hostPort") + } + + test ("simple foreach with accumulator") { val output = runInterpreter("local", """ |val accum = sc.accumulator(0) |sc.parallelize(1 to 10).foreach(x => accum += x) diff --git a/spark-class b/spark-class index 359db3d984..78d6e073b1 100755 --- a/spark-class +++ b/spark-class @@ -110,8 +110,21 @@ if [ ! -f "$FWDIR/RELEASE" ]; then fi fi +TOOLS_DIR="$FWDIR"/tools +SPARK_TOOLS_JAR="" +if [ -e "$TOOLS_DIR"/target/scala-$SCALA_VERSION/*assembly*[0-9Tg].jar ]; then + # Use the JAR from the SBT build + export SPARK_TOOLS_JAR=`ls "$TOOLS_DIR"/target/scala-$SCALA_VERSION/*assembly*[0-9Tg].jar` +fi +if [ -e "$TOOLS_DIR"/target/spark-tools*[0-9Tg].jar ]; then + # Use the JAR from the Maven build + # TODO: this also needs to become an assembly! + export SPARK_TOOLS_JAR=`ls "$TOOLS_DIR"/target/spark-tools*[0-9Tg].jar` +fi + # Compute classpath using external script CLASSPATH=`$FWDIR/bin/compute-classpath.sh` +CLASSPATH="$SPARK_TOOLS_JAR:$CLASSPATH" export CLASSPATH if [ "$SPARK_PRINT_LAUNCH_COMMAND" == "1" ]; then diff --git a/spark-class2.cmd b/spark-class2.cmd index d4d853e8ad..3869d0761b 100644 --- a/spark-class2.cmd +++ b/spark-class2.cmd @@ -65,10 +65,17 @@ if "%FOUND_JAR%"=="0" ( ) :skip_build_test +set TOOLS_DIR=%FWDIR%tools +set SPARK_TOOLS_JAR= +for %%d in ("%TOOLS_DIR%\target\scala-%SCALA_VERSION%\spark-tools*assembly*.jar") do ( + set SPARK_TOOLS_JAR=%%d +) + rem Compute classpath using external script set DONT_PRINT_CLASSPATH=1 call "%FWDIR%bin\compute-classpath.cmd" set DONT_PRINT_CLASSPATH=0 +set CLASSPATH=%SPARK_TOOLS_JAR%;%CLASSPATH% rem Figure out where java is. 
set RUNNER=java diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/NetworkInputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/NetworkInputDStream.scala index ab97ee9349..e90557d9b7 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/NetworkInputDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/NetworkInputDStream.scala @@ -233,11 +233,11 @@ abstract class NetworkReceiver[T: ClassTag]() extends Serializable with Logging logInfo("Data handler stopped") } - def += (obj: T) { + def += (obj: T): Unit = synchronized { currentBuffer += obj } - private def updateCurrentBuffer(time: Long) { + private def updateCurrentBuffer(time: Long): Unit = synchronized { try { val newBlockBuffer = currentBuffer currentBuffer = new ArrayBuffer[T] diff --git a/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala index c29b75ece6..a559db468a 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala @@ -23,15 +23,15 @@ import akka.actor.IOManager import akka.actor.Props import akka.util.ByteString -import dstream.SparkFlumeEvent +import org.apache.spark.streaming.dstream.{NetworkReceiver, SparkFlumeEvent} import java.net.{InetSocketAddress, SocketException, Socket, ServerSocket} import java.io.{File, BufferedWriter, OutputStreamWriter} -import java.util.concurrent.{TimeUnit, ArrayBlockingQueue} +import java.util.concurrent.{Executors, TimeUnit, ArrayBlockingQueue} import collection.mutable.{SynchronizedBuffer, ArrayBuffer} import util.ManualClock import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming.receivers.Receiver -import org.apache.spark.Logging +import org.apache.spark.{SparkContext, Logging} import scala.util.Random import 
org.apache.commons.io.FileUtils import org.scalatest.BeforeAndAfter @@ -44,6 +44,7 @@ import java.nio.ByteBuffer import collection.JavaConversions._ import java.nio.charset.Charset import com.google.common.io.Files +import java.util.concurrent.atomic.AtomicInteger class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter { @@ -61,7 +62,6 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter { System.clearProperty("spark.hostPort") } - test("socket input stream") { // Start the server val testServer = new TestServer() @@ -275,10 +275,49 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter { kafka.serializer.StringDecoder, kafka.serializer.StringDecoder](kafkaParams, topics, StorageLevel.MEMORY_AND_DISK) } + + test("multi-thread receiver") { + // set up the test receiver + val numThreads = 10 + val numRecordsPerThread = 1000 + val numTotalRecords = numThreads * numRecordsPerThread + val testReceiver = new MultiThreadTestReceiver(numThreads, numRecordsPerThread) + MultiThreadTestReceiver.haveAllThreadsFinished = false + + // set up the network stream using the test receiver + val ssc = new StreamingContext(master, framework, batchDuration) + val networkStream = ssc.networkStream[Int](testReceiver) + val countStream = networkStream.count + val outputBuffer = new ArrayBuffer[Seq[Long]] with SynchronizedBuffer[Seq[Long]] + val outputStream = new TestOutputStream(countStream, outputBuffer) + def output = outputBuffer.flatMap(x => x) + ssc.registerOutputStream(outputStream) + ssc.start() + + // Let the data from the receiver be received + val clock = ssc.scheduler.clock.asInstanceOf[ManualClock] + val startTime = System.currentTimeMillis() + while((!MultiThreadTestReceiver.haveAllThreadsFinished || output.sum < numTotalRecords) && + System.currentTimeMillis() - startTime < 5000) { + Thread.sleep(100) + clock.addToTime(batchDuration.milliseconds) + } + Thread.sleep(1000) + logInfo("Stopping context") + ssc.stop() + + // Verify whether 
data received was as expected + logInfo("--------------------------------") + logInfo("output.size = " + outputBuffer.size) + logInfo("output") + outputBuffer.foreach(x => logInfo("[" + x.mkString(",") + "]")) + logInfo("--------------------------------") + assert(output.sum === numTotalRecords) + } } -/** This is server to test the network input stream */ +/** This is a server to test the network input stream */ class TestServer() extends Logging { val queue = new ArrayBlockingQueue[String](100) @@ -340,6 +379,7 @@ object TestServer { } } +/** This is an actor for testing actor input stream */ class TestActor(port: Int) extends Actor with Receiver { def bytesToString(byteString: ByteString) = byteString.utf8String @@ -351,3 +391,36 @@ class TestActor(port: Int) extends Actor with Receiver { pushBlock(bytesToString(bytes)) } } + +/** This is a receiver to test multiple threads inserting data using block generator */ +class MultiThreadTestReceiver(numThreads: Int, numRecordsPerThread: Int) + extends NetworkReceiver[Int] { + lazy val executorPool = Executors.newFixedThreadPool(numThreads) + lazy val blockGenerator = new BlockGenerator(StorageLevel.MEMORY_ONLY) + lazy val finishCount = new AtomicInteger(0) + + protected def onStart() { + blockGenerator.start() + (1 to numThreads).map(threadId => { + val runnable = new Runnable { + def run() { + (1 to numRecordsPerThread).foreach(i => + blockGenerator += (threadId * numRecordsPerThread + i) ) + if (finishCount.incrementAndGet == numThreads) { + MultiThreadTestReceiver.haveAllThreadsFinished = true + } + logInfo("Finished thread " + threadId) + } + } + executorPool.submit(runnable) + }) + } + + protected def onStop() { + executorPool.shutdown() + } +} + +object MultiThreadTestReceiver { + var haveAllThreadsFinished = false +} diff --git a/yarn/pom.xml b/yarn/pom.xml index 7770cbb0cc..12bc97da8a 100644 --- a/yarn/pom.xml +++ b/yarn/pom.xml @@ -61,6 +61,16 @@ <groupId>org.apache.avro</groupId> 
<artifactId>avro-ipc</artifactId> </dependency> + <dependency> + <groupId>org.scalatest</groupId> + <artifactId>scalatest_2.9.3</artifactId> + <scope>test</scope> + </dependency> + <dependency> + <groupId>org.mockito</groupId> + <artifactId>mockito-all</artifactId> + <scope>test</scope> + </dependency> </dependencies> <build> @@ -106,6 +116,46 @@ </execution> </executions> </plugin> + <plugin> + <groupId>org.apache.maven.plugins</groupId> + <artifactId>maven-antrun-plugin</artifactId> + <executions> + <execution> + <phase>test</phase> + <goals> + <goal>run</goal> + </goals> + <configuration> + <exportAntProperties>true</exportAntProperties> + <tasks> + <property name="spark.classpath" refid="maven.test.classpath" /> + <property environment="env" /> + <fail message="Please set the SCALA_HOME (or SCALA_LIBRARY_PATH if scala is on the path) environment variables and retry."> + <condition> + <not> + <or> + <isset property="env.SCALA_HOME" /> + <isset property="env.SCALA_LIBRARY_PATH" /> + </or> + </not> + </condition> + </fail> + </tasks> + </configuration> + </execution> + </executions> + </plugin> + <plugin> + <groupId>org.scalatest</groupId> + <artifactId>scalatest-maven-plugin</artifactId> + <configuration> + <environmentVariables> + <SPARK_HOME>${basedir}/..</SPARK_HOME> + <SPARK_TESTING>1</SPARK_TESTING> + <SPARK_CLASSPATH>${spark.classpath}</SPARK_CLASSPATH> + </environmentVariables> + </configuration> + </plugin> </plugins> </build> </project> diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala index c1a87d3373..4302ef4cda 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala @@ -349,7 +349,7 @@ class ApplicationMaster(args: ApplicationMasterArguments, conf: Configuration) e try { val preserveFiles = 
System.getProperty("spark.yarn.preserve.staging.files", "false").toBoolean if (!preserveFiles) { - stagingDirPath = new Path(System.getenv("SPARK_YARN_JAR_PATH")).getParent() + stagingDirPath = new Path(System.getenv("SPARK_YARN_STAGING_DIR")) if (stagingDirPath == null) { logError("Staging directory is null") return diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala index 1a380ae714..4e0e060ddc 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala @@ -17,26 +17,31 @@ package org.apache.spark.deploy.yarn -import java.net.{InetSocketAddress, URI} +import java.net.{InetAddress, InetSocketAddress, UnknownHostException, URI} import java.nio.ByteBuffer + import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.fs.{FileStatus, FileSystem, Path} +import org.apache.hadoop.fs.{FileContext, FileStatus, FileSystem, Path, FileUtil} +import org.apache.hadoop.fs.permission.FsPermission; import org.apache.hadoop.mapred.Master import org.apache.hadoop.net.NetUtils import org.apache.hadoop.io.DataOutputBuffer import org.apache.hadoop.security.UserGroupInformation import org.apache.hadoop.yarn.api._ +import org.apache.hadoop.yarn.api.ApplicationConstants.Environment import org.apache.hadoop.yarn.api.records._ import org.apache.hadoop.yarn.api.protocolrecords._ import org.apache.hadoop.yarn.client.YarnClientImpl import org.apache.hadoop.yarn.conf.YarnConfiguration import org.apache.hadoop.yarn.ipc.YarnRPC +import org.apache.hadoop.yarn.util.{Apps, Records} + import scala.collection.mutable.HashMap +import scala.collection.mutable.Map import scala.collection.JavaConversions._ + import org.apache.spark.Logging import org.apache.spark.util.Utils -import org.apache.hadoop.yarn.util.{Apps, Records, ConverterUtils} -import org.apache.hadoop.yarn.api.ApplicationConstants.Environment import 
org.apache.spark.deploy.SparkHadoopUtil class Client(conf: Configuration, args: ClientArguments) extends YarnClientImpl with Logging { @@ -46,13 +51,14 @@ class Client(conf: Configuration, args: ClientArguments) extends YarnClientImpl var rpc: YarnRPC = YarnRPC.create(conf) val yarnConf: YarnConfiguration = new YarnConfiguration(conf) val credentials = UserGroupInformation.getCurrentUser().getCredentials() - private var distFiles = None: Option[String] - private var distFilesTimeStamps = None: Option[String] - private var distFilesFileSizes = None: Option[String] - private var distArchives = None: Option[String] - private var distArchivesTimeStamps = None: Option[String] - private var distArchivesFileSizes = None: Option[String] - + private val SPARK_STAGING: String = ".sparkStaging" + private val distCacheMgr = new ClientDistributedCacheManager() + + // staging directory is private! -> rwx-------- + val STAGING_DIR_PERMISSION: FsPermission = FsPermission.createImmutable(0700:Short) + // app files are world-wide readable and owner writable -> rw-r--r-- + val APP_FILE_PERMISSION: FsPermission = FsPermission.createImmutable(0644:Short) + def run() { init(yarnConf) start() @@ -63,8 +69,9 @@ class Client(conf: Configuration, args: ClientArguments) extends YarnClientImpl verifyClusterResources(newApp) val appContext = createApplicationSubmissionContext(appId) - val localResources = prepareLocalResources(appId, ".sparkStaging") - val env = setupLaunchEnv(localResources) + val appStagingDir = getAppStagingDir(appId) + val localResources = prepareLocalResources(appStagingDir) + val env = setupLaunchEnv(localResources, appStagingDir) val amContainer = createContainerLaunchContext(newApp, localResources, env) appContext.setQueue(args.amQueue) @@ -76,7 +83,10 @@ class Client(conf: Configuration, args: ClientArguments) extends YarnClientImpl monitorApplication(appId) System.exit(0) } - + + def getAppStagingDir(appId: ApplicationId): String = { + SPARK_STAGING + Path.SEPARATOR 
+ appId.toString() + Path.SEPARATOR + } def logClusterResourceDetails() { val clusterMetrics: YarnClusterMetrics = super.getYarnClusterMetrics @@ -116,73 +126,73 @@ class Client(conf: Configuration, args: ClientArguments) extends YarnClientImpl return appContext } + /* + * see if two file systems are the same or not. + */ + private def compareFs(srcFs: FileSystem, destFs: FileSystem): Boolean = { + val srcUri = srcFs.getUri() + val dstUri = destFs.getUri() + if (srcUri.getScheme() == null) { + return false + } + if (!srcUri.getScheme().equals(dstUri.getScheme())) { + return false + } + var srcHost = srcUri.getHost() + var dstHost = dstUri.getHost() + if ((srcHost != null) && (dstHost != null)) { + try { + srcHost = InetAddress.getByName(srcHost).getCanonicalHostName(); + dstHost = InetAddress.getByName(dstHost).getCanonicalHostName(); + } catch { + case e: UnknownHostException => + return false + } + if (!srcHost.equals(dstHost)) { + return false + } + } else if (srcHost == null && dstHost != null) { + return false + } else if (srcHost != null && dstHost == null) { + return false + } + //check for ports + if (srcUri.getPort() != dstUri.getPort()) { + return false + } + return true; + } + /** - * Copy the local file into HDFS and configure to be distributed with the - * job via the distributed cache. - * If a fragment is specified the file will be referenced as that fragment. + * Copy the file into HDFS if needed. 
*/ - private def copyLocalFile( + private def copyRemoteFile( dstDir: Path, - resourceType: LocalResourceType, originalPath: Path, replication: Short, - localResources: HashMap[String,LocalResource], - fragment: String, - appMasterOnly: Boolean = false): Unit = { + setPerms: Boolean = false): Path = { val fs = FileSystem.get(conf) - val newPath = new Path(dstDir, originalPath.getName()) - logInfo("Uploading " + originalPath + " to " + newPath) - fs.copyFromLocalFile(false, true, originalPath, newPath) - fs.setReplication(newPath, replication); - val destStatus = fs.getFileStatus(newPath) - - val amJarRsrc = Records.newRecord(classOf[LocalResource]).asInstanceOf[LocalResource] - amJarRsrc.setType(resourceType) - amJarRsrc.setVisibility(LocalResourceVisibility.APPLICATION) - amJarRsrc.setResource(ConverterUtils.getYarnUrlFromPath(newPath)) - amJarRsrc.setTimestamp(destStatus.getModificationTime()) - amJarRsrc.setSize(destStatus.getLen()) - var pathURI: URI = new URI(newPath.toString() + "#" + originalPath.getName()); - if ((fragment == null) || (fragment.isEmpty())){ - localResources(originalPath.getName()) = amJarRsrc - } else { - localResources(fragment) = amJarRsrc - pathURI = new URI(newPath.toString() + "#" + fragment); - } - val distPath = pathURI.toString() - if (appMasterOnly == true) return - if (resourceType == LocalResourceType.FILE) { - distFiles match { - case Some(path) => - distFilesFileSizes = Some(distFilesFileSizes.get + "," + - destStatus.getLen().toString()) - distFilesTimeStamps = Some(distFilesTimeStamps.get + "," + - destStatus.getModificationTime().toString()) - distFiles = Some(path + "," + distPath) - case _ => - distFilesFileSizes = Some(destStatus.getLen().toString()) - distFilesTimeStamps = Some(destStatus.getModificationTime().toString()) - distFiles = Some(distPath) - } - } else { - distArchives match { - case Some(path) => - distArchivesTimeStamps = Some(distArchivesTimeStamps.get + "," + - destStatus.getModificationTime().toString()) 
- distArchivesFileSizes = Some(distArchivesFileSizes.get + "," + - destStatus.getLen().toString()) - distArchives = Some(path + "," + distPath) - case _ => - distArchivesTimeStamps = Some(destStatus.getModificationTime().toString()) - distArchivesFileSizes = Some(destStatus.getLen().toString()) - distArchives = Some(distPath) - } - } + val remoteFs = originalPath.getFileSystem(conf); + var newPath = originalPath + if (! compareFs(remoteFs, fs)) { + newPath = new Path(dstDir, originalPath.getName()) + logInfo("Uploading " + originalPath + " to " + newPath) + FileUtil.copy(remoteFs, originalPath, fs, newPath, false, conf); + fs.setReplication(newPath, replication); + if (setPerms) fs.setPermission(newPath, new FsPermission(APP_FILE_PERMISSION)) + } + // resolve any symlinks in the URI path so using a "current" symlink + // to point to a specific version shows the specific version + // in the distributed cache configuration + val qualPath = fs.makeQualified(newPath) + val fc = FileContext.getFileContext(qualPath.toUri(), conf) + val destPath = fc.resolvePath(qualPath) + destPath } - def prepareLocalResources(appId: ApplicationId, sparkStagingDir: String): HashMap[String, LocalResource] = { + def prepareLocalResources(appStagingDir: String): HashMap[String, LocalResource] = { logInfo("Preparing Local resources") - // Upload Spark and the application JAR to the remote file system + // Upload Spark and the application JAR to the remote file system if necessary // Add them as local resources to the AM val fs = FileSystem.get(conf) @@ -193,9 +203,7 @@ class Client(conf: Configuration, args: ClientArguments) extends YarnClientImpl System.exit(1) } } - - val pathSuffix = sparkStagingDir + "/" + appId.toString() + "/" - val dst = new Path(fs.getHomeDirectory(), pathSuffix) + val dst = new Path(fs.getHomeDirectory(), appStagingDir) val replication = System.getProperty("spark.yarn.submit.file.replication", "3").toShort if (UserGroupInformation.isSecurityEnabled()) { @@ -203,55 
+211,65 @@ class Client(conf: Configuration, args: ClientArguments) extends YarnClientImpl dstFs.addDelegationTokens(delegTokenRenewer, credentials); } val localResources = HashMap[String, LocalResource]() + FileSystem.mkdirs(fs, dst, new FsPermission(STAGING_DIR_PERMISSION)) + + val statCache: Map[URI, FileStatus] = HashMap[URI, FileStatus]() + + if (System.getenv("SPARK_JAR") == null || args.userJar == null) { + logError("Error: You must set SPARK_JAR environment variable and specify a user jar!") + System.exit(1) + } - Map("spark.jar" -> System.getenv("SPARK_JAR"), "app.jar" -> args.userJar, "log4j.properties" -> System.getenv("SPARK_LOG4J_CONF")) + Map(Client.SPARK_JAR -> System.getenv("SPARK_JAR"), Client.APP_JAR -> args.userJar, + Client.LOG4J_PROP -> System.getenv("SPARK_LOG4J_CONF")) .foreach { case(destName, _localPath) => val localPath: String = if (_localPath != null) _localPath.trim() else "" if (! localPath.isEmpty()) { - val src = new Path(localPath) - val newPath = new Path(dst, destName) - logInfo("Uploading " + src + " to " + newPath) - fs.copyFromLocalFile(false, true, src, newPath) - fs.setReplication(newPath, replication); - val destStatus = fs.getFileStatus(newPath) - - val amJarRsrc = Records.newRecord(classOf[LocalResource]).asInstanceOf[LocalResource] - amJarRsrc.setType(LocalResourceType.FILE) - amJarRsrc.setVisibility(LocalResourceVisibility.APPLICATION) - amJarRsrc.setResource(ConverterUtils.getYarnUrlFromPath(newPath)) - amJarRsrc.setTimestamp(destStatus.getModificationTime()) - amJarRsrc.setSize(destStatus.getLen()) - localResources(destName) = amJarRsrc + var localURI = new URI(localPath) + // if not specified assume these are in the local filesystem to keep behavior like Hadoop + if (localURI.getScheme() == null) { + localURI = new URI(FileSystem.getLocal(conf).makeQualified(new Path(localPath)).toString()) + } + val setPermissions = if (destName.equals(Client.APP_JAR)) true else false + val destPath = copyRemoteFile(dst, new 
Path(localURI), replication, setPermissions) + distCacheMgr.addResource(fs, conf, destPath, localResources, LocalResourceType.FILE, + destName, statCache) } } // handle any add jars if ((args.addJars != null) && (!args.addJars.isEmpty())){ args.addJars.split(',').foreach { case file: String => - val tmpURI = new URI(file) - val tmp = new Path(tmpURI) - copyLocalFile(dst, LocalResourceType.FILE, tmp, replication, localResources, - tmpURI.getFragment(), true) + val localURI = new URI(file.trim()) + val localPath = new Path(localURI) + val linkname = Option(localURI.getFragment()).getOrElse(localPath.getName()) + val destPath = copyRemoteFile(dst, localPath, replication) + distCacheMgr.addResource(fs, conf, destPath, localResources, LocalResourceType.FILE, + linkname, statCache, true) } } // handle any distributed cache files if ((args.files != null) && (!args.files.isEmpty())){ args.files.split(',').foreach { case file: String => - val tmpURI = new URI(file) - val tmp = new Path(tmpURI) - copyLocalFile(dst, LocalResourceType.FILE, tmp, replication, localResources, - tmpURI.getFragment()) + val localURI = new URI(file.trim()) + val localPath = new Path(localURI) + val linkname = Option(localURI.getFragment()).getOrElse(localPath.getName()) + val destPath = copyRemoteFile(dst, localPath, replication) + distCacheMgr.addResource(fs, conf, destPath, localResources, LocalResourceType.FILE, + linkname, statCache) } } // handle any distributed cache archives if ((args.archives != null) && (!args.archives.isEmpty())) { args.archives.split(',').foreach { case file:String => - val tmpURI = new URI(file) - val tmp = new Path(tmpURI) - copyLocalFile(dst, LocalResourceType.ARCHIVE, tmp, replication, - localResources, tmpURI.getFragment()) + val localURI = new URI(file.trim()) + val localPath = new Path(localURI) + val linkname = Option(localURI.getFragment()).getOrElse(localPath.getName()) + val destPath = copyRemoteFile(dst, localPath, replication) + distCacheMgr.addResource(fs, 
conf, destPath, localResources, LocalResourceType.ARCHIVE, + linkname, statCache) } } @@ -259,44 +277,21 @@ class Client(conf: Configuration, args: ClientArguments) extends YarnClientImpl return localResources } - def setupLaunchEnv(localResources: HashMap[String, LocalResource]): HashMap[String, String] = { + def setupLaunchEnv( + localResources: HashMap[String, LocalResource], + stagingDir: String): HashMap[String, String] = { logInfo("Setting up the launch environment") - val log4jConfLocalRes = localResources.getOrElse("log4j.properties", null) + val log4jConfLocalRes = localResources.getOrElse(Client.LOG4J_PROP, null) val env = new HashMap[String, String]() Client.populateClasspath(yarnConf, log4jConfLocalRes != null, env) env("SPARK_YARN_MODE") = "true" - env("SPARK_YARN_JAR_PATH") = - localResources("spark.jar").getResource().getScheme.toString() + "://" + - localResources("spark.jar").getResource().getFile().toString() - env("SPARK_YARN_JAR_TIMESTAMP") = localResources("spark.jar").getTimestamp().toString() - env("SPARK_YARN_JAR_SIZE") = localResources("spark.jar").getSize().toString() - - env("SPARK_YARN_USERJAR_PATH") = - localResources("app.jar").getResource().getScheme.toString() + "://" + - localResources("app.jar").getResource().getFile().toString() - env("SPARK_YARN_USERJAR_TIMESTAMP") = localResources("app.jar").getTimestamp().toString() - env("SPARK_YARN_USERJAR_SIZE") = localResources("app.jar").getSize().toString() - - if (log4jConfLocalRes != null) { - env("SPARK_YARN_LOG4J_PATH") = - log4jConfLocalRes.getResource().getScheme.toString() + "://" + log4jConfLocalRes.getResource().getFile().toString() - env("SPARK_YARN_LOG4J_TIMESTAMP") = log4jConfLocalRes.getTimestamp().toString() - env("SPARK_YARN_LOG4J_SIZE") = log4jConfLocalRes.getSize().toString() - } + env("SPARK_YARN_STAGING_DIR") = stagingDir // set the environment variables to be passed on to the Workers - if (distFiles != None) { - env("SPARK_YARN_CACHE_FILES") = distFiles.get - 
env("SPARK_YARN_CACHE_FILES_TIME_STAMPS") = distFilesTimeStamps.get - env("SPARK_YARN_CACHE_FILES_FILE_SIZES") = distFilesFileSizes.get - } - if (distArchives != None) { - env("SPARK_YARN_CACHE_ARCHIVES") = distArchives.get - env("SPARK_YARN_CACHE_ARCHIVES_TIME_STAMPS") = distArchivesTimeStamps.get - env("SPARK_YARN_CACHE_ARCHIVES_FILE_SIZES") = distArchivesFileSizes.get - } + distCacheMgr.setDistFilesEnv(env) + distCacheMgr.setDistArchivesEnv(env) // allow users to specify some environment variables Apps.setEnvFromInputString(env, System.getenv("SPARK_YARN_USER_ENV")) @@ -365,6 +360,11 @@ class Client(conf: Configuration, args: ClientArguments) extends YarnClientImpl javaCommand = Environment.JAVA_HOME.$() + "/bin/java" } + if (args.userClass == null) { + logError("Error: You must specify a user class!") + System.exit(1) + } + val commands = List[String](javaCommand + " -server " + JAVA_OPTS + @@ -432,6 +432,10 @@ class Client(conf: Configuration, args: ClientArguments) extends YarnClientImpl } object Client { + val SPARK_JAR: String = "spark.jar" + val APP_JAR: String = "app.jar" + val LOG4J_PROP: String = "log4j.properties" + def main(argStrings: Array[String]) { // Set an env variable indicating we are running in YARN mode. 
// Note that anything with SPARK prefix gets propagated to all (remote) processes @@ -453,22 +457,22 @@ object Client { // If log4j present, ensure ours overrides all others if (addLog4j) { Apps.addToEnvironment(env, Environment.CLASSPATH.name, Environment.PWD.$() + - Path.SEPARATOR + "log4j.properties") + Path.SEPARATOR + LOG4J_PROP) } // normally the users app.jar is last in case conflicts with spark jars val userClasspathFirst = System.getProperty("spark.yarn.user.classpath.first", "false") .toBoolean if (userClasspathFirst) { Apps.addToEnvironment(env, Environment.CLASSPATH.name, Environment.PWD.$() + - Path.SEPARATOR + "app.jar") + Path.SEPARATOR + APP_JAR) } Apps.addToEnvironment(env, Environment.CLASSPATH.name, Environment.PWD.$() + - Path.SEPARATOR + "spark.jar") + Path.SEPARATOR + SPARK_JAR) Client.populateHadoopClasspath(conf, env) if (!userClasspathFirst) { Apps.addToEnvironment(env, Environment.CLASSPATH.name, Environment.PWD.$() + - Path.SEPARATOR + "app.jar") + Path.SEPARATOR + APP_JAR) } Apps.addToEnvironment(env, Environment.CLASSPATH.name, Environment.PWD.$() + Path.SEPARATOR + "*") diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientDistributedCacheManager.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientDistributedCacheManager.scala new file mode 100644 index 0000000000..07686fefd7 --- /dev/null +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientDistributedCacheManager.scala @@ -0,0 +1,228 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy.yarn + +import java.net.URI; + +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.FileStatus +import org.apache.hadoop.fs.FileSystem +import org.apache.hadoop.fs.Path +import org.apache.hadoop.fs.permission.FsAction +import org.apache.hadoop.yarn.api.records.LocalResource +import org.apache.hadoop.yarn.api.records.LocalResourceVisibility +import org.apache.hadoop.yarn.api.records.LocalResourceType +import org.apache.hadoop.yarn.util.{Records, ConverterUtils} + +import org.apache.spark.Logging + +import scala.collection.mutable.HashMap +import scala.collection.mutable.LinkedHashMap +import scala.collection.mutable.Map + + +/** Client side methods to setup the Hadoop distributed cache */ +class ClientDistributedCacheManager() extends Logging { + private val distCacheFiles: Map[String, Tuple3[String, String, String]] = + LinkedHashMap[String, Tuple3[String, String, String]]() + private val distCacheArchives: Map[String, Tuple3[String, String, String]] = + LinkedHashMap[String, Tuple3[String, String, String]]() + + + /** + * Add a resource to the list of distributed cache resources. This list can + * be sent to the ApplicationMaster and possibly the workers so that it can + * be downloaded into the Hadoop distributed cache for use by this application. + * Adds the LocalResource to the localResources HashMap passed in and saves + * the stats of the resources so they can be sent to the workers and verified. 
+ * + * @param fs FileSystem + * @param conf Configuration + * @param destPath path to the resource + * @param localResources localResource hashMap to insert the resource into + * @param resourceType LocalResourceType + * @param link link presented in the distributed cache to the destination + * @param statCache cache to store the file/directory stats + * @param appMasterOnly Whether to only add the resource to the app master + */ + def addResource( + fs: FileSystem, + conf: Configuration, + destPath: Path, + localResources: HashMap[String, LocalResource], + resourceType: LocalResourceType, + link: String, + statCache: Map[URI, FileStatus], + appMasterOnly: Boolean = false) = { + val destStatus = fs.getFileStatus(destPath) + val amJarRsrc = Records.newRecord(classOf[LocalResource]).asInstanceOf[LocalResource] + amJarRsrc.setType(resourceType) + val visibility = getVisibility(conf, destPath.toUri(), statCache) + amJarRsrc.setVisibility(visibility) + amJarRsrc.setResource(ConverterUtils.getYarnUrlFromPath(destPath)) + amJarRsrc.setTimestamp(destStatus.getModificationTime()) + amJarRsrc.setSize(destStatus.getLen()) + if (link == null || link.isEmpty()) throw new Exception("You must specify a valid link name") + localResources(link) = amJarRsrc + + if (appMasterOnly == false) { + val uri = destPath.toUri() + val pathURI = new URI(uri.getScheme(), uri.getAuthority(), uri.getPath(), null, link) + if (resourceType == LocalResourceType.FILE) { + distCacheFiles(pathURI.toString()) = (destStatus.getLen().toString(), + destStatus.getModificationTime().toString(), visibility.name()) + } else { + distCacheArchives(pathURI.toString()) = (destStatus.getLen().toString(), + destStatus.getModificationTime().toString(), visibility.name()) + } + } + } + + /** + * Adds the necessary cache file env variables to the env passed in + * @param env + */ + def setDistFilesEnv(env: Map[String, String]) = { + val (keys, tupleValues) = distCacheFiles.unzip + val (sizes, timeStamps, visibilities) 
= tupleValues.unzip3 + + if (keys.size > 0) { + env("SPARK_YARN_CACHE_FILES") = keys.reduceLeft[String] { (acc,n) => acc + "," + n } + env("SPARK_YARN_CACHE_FILES_TIME_STAMPS") = + timeStamps.reduceLeft[String] { (acc,n) => acc + "," + n } + env("SPARK_YARN_CACHE_FILES_FILE_SIZES") = + sizes.reduceLeft[String] { (acc,n) => acc + "," + n } + env("SPARK_YARN_CACHE_FILES_VISIBILITIES") = + visibilities.reduceLeft[String] { (acc,n) => acc + "," + n } + } + } + + /** + * Adds the necessary cache archive env variables to the env passed in + * @param env + */ + def setDistArchivesEnv(env: Map[String, String]) = { + val (keys, tupleValues) = distCacheArchives.unzip + val (sizes, timeStamps, visibilities) = tupleValues.unzip3 + + if (keys.size > 0) { + env("SPARK_YARN_CACHE_ARCHIVES") = keys.reduceLeft[String] { (acc,n) => acc + "," + n } + env("SPARK_YARN_CACHE_ARCHIVES_TIME_STAMPS") = + timeStamps.reduceLeft[String] { (acc,n) => acc + "," + n } + env("SPARK_YARN_CACHE_ARCHIVES_FILE_SIZES") = + sizes.reduceLeft[String] { (acc,n) => acc + "," + n } + env("SPARK_YARN_CACHE_ARCHIVES_VISIBILITIES") = + visibilities.reduceLeft[String] { (acc,n) => acc + "," + n } + } + } + + /** + * Returns the local resource visibility depending on the cache file permissions + * @param conf + * @param uri + * @param statCache + * @return LocalResourceVisibility + */ + def getVisibility(conf: Configuration, uri: URI, statCache: Map[URI, FileStatus]): + LocalResourceVisibility = { + if (isPublic(conf, uri, statCache)) { + return LocalResourceVisibility.PUBLIC + } + return LocalResourceVisibility.PRIVATE + } + + /** + * Returns a boolean to denote whether a cache file is visible to all(public) + * or not + * @param conf + * @param uri + * @param statCache + * @return true if the path in the uri is visible to all, false otherwise + */ + def isPublic(conf: Configuration, uri: URI, statCache: Map[URI, FileStatus]): Boolean = { + val fs = FileSystem.get(uri, conf) + val current = new 
Path(uri.getPath()) + //the leaf level file should be readable by others + if (!checkPermissionOfOther(fs, current, FsAction.READ, statCache)) { + return false + } + return ancestorsHaveExecutePermissions(fs, current.getParent(), statCache) + } + + /** + * Returns true if all ancestors of the specified path have the 'execute' + * permission set for all users (i.e. that other users can traverse + * the directory hierarchy to the given path) + * @param fs + * @param path + * @param statCache + * @return true if all ancestors have the 'execute' permission set for all users + */ + def ancestorsHaveExecutePermissions(fs: FileSystem, path: Path, + statCache: Map[URI, FileStatus]): Boolean = { + var current = path + while (current != null) { + //the subdirs in the path should have execute permissions for others + if (!checkPermissionOfOther(fs, current, FsAction.EXECUTE, statCache)) { + return false + } + current = current.getParent() + } + return true + } + + /** + * Checks for a given path whether the Other permissions on it + * imply the permission in the passed FsAction + * @param fs + * @param path + * @param action + * @param statCache + * @return true if the Other permissions on the path imply the given action, false otherwise + */ + def checkPermissionOfOther(fs: FileSystem, path: Path, + action: FsAction, statCache: Map[URI, FileStatus]): Boolean = { + val status = getFileStatus(fs, path.toUri(), statCache); + val perms = status.getPermission() + val otherAction = perms.getOtherAction() + if (otherAction.implies(action)) { + return true; + } + return false + } + + /** + * Checks to see if the given uri exists in the cache, if it does it + * returns the existing FileStatus, otherwise it stats the uri, stores + * it in the cache, and returns the FileStatus. 
+ * @param fs + * @param uri + * @param statCache + * @return FileStatus + */ + def getFileStatus(fs: FileSystem, uri: URI, statCache: Map[URI, FileStatus]): FileStatus = { + val stat = statCache.get(uri) match { + case Some(existstat) => existstat + case None => + val newStat = fs.getFileStatus(new Path(uri)) + statCache.put(uri, newStat) + newStat + } + return stat + } +} diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/WorkerRunnable.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/WorkerRunnable.scala index ba352daac4..7a66532254 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/WorkerRunnable.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/WorkerRunnable.scala @@ -142,11 +142,12 @@ class WorkerRunnable(container: Container, conf: Configuration, masterAddress: S rtype: LocalResourceType, localResources: HashMap[String, LocalResource], timestamp: String, - size: String) = { + size: String, + vis: String) = { val uri = new URI(file) val amJarRsrc = Records.newRecord(classOf[LocalResource]).asInstanceOf[LocalResource] amJarRsrc.setType(rtype) - amJarRsrc.setVisibility(LocalResourceVisibility.APPLICATION) + amJarRsrc.setVisibility(LocalResourceVisibility.valueOf(vis)) amJarRsrc.setResource(ConverterUtils.getYarnUrlFromURI(uri)) amJarRsrc.setTimestamp(timestamp.toLong) amJarRsrc.setSize(size.toLong) @@ -158,44 +159,14 @@ class WorkerRunnable(container: Container, conf: Configuration, masterAddress: S logInfo("Preparing Local resources") val localResources = HashMap[String, LocalResource]() - // Spark JAR - val sparkJarResource = Records.newRecord(classOf[LocalResource]).asInstanceOf[LocalResource] - sparkJarResource.setType(LocalResourceType.FILE) - sparkJarResource.setVisibility(LocalResourceVisibility.APPLICATION) - sparkJarResource.setResource(ConverterUtils.getYarnUrlFromURI( - new URI(System.getenv("SPARK_YARN_JAR_PATH")))) - sparkJarResource.setTimestamp(System.getenv("SPARK_YARN_JAR_TIMESTAMP").toLong) - 
sparkJarResource.setSize(System.getenv("SPARK_YARN_JAR_SIZE").toLong) - localResources("spark.jar") = sparkJarResource - // User JAR - val userJarResource = Records.newRecord(classOf[LocalResource]).asInstanceOf[LocalResource] - userJarResource.setType(LocalResourceType.FILE) - userJarResource.setVisibility(LocalResourceVisibility.APPLICATION) - userJarResource.setResource(ConverterUtils.getYarnUrlFromURI( - new URI(System.getenv("SPARK_YARN_USERJAR_PATH")))) - userJarResource.setTimestamp(System.getenv("SPARK_YARN_USERJAR_TIMESTAMP").toLong) - userJarResource.setSize(System.getenv("SPARK_YARN_USERJAR_SIZE").toLong) - localResources("app.jar") = userJarResource - - // Log4j conf - if available - if (System.getenv("SPARK_YARN_LOG4J_PATH") != null) { - val log4jConfResource = Records.newRecord(classOf[LocalResource]).asInstanceOf[LocalResource] - log4jConfResource.setType(LocalResourceType.FILE) - log4jConfResource.setVisibility(LocalResourceVisibility.APPLICATION) - log4jConfResource.setResource(ConverterUtils.getYarnUrlFromURI( - new URI(System.getenv("SPARK_YARN_LOG4J_PATH")))) - log4jConfResource.setTimestamp(System.getenv("SPARK_YARN_LOG4J_TIMESTAMP").toLong) - log4jConfResource.setSize(System.getenv("SPARK_YARN_LOG4J_SIZE").toLong) - localResources("log4j.properties") = log4jConfResource - } - if (System.getenv("SPARK_YARN_CACHE_FILES") != null) { val timeStamps = System.getenv("SPARK_YARN_CACHE_FILES_TIME_STAMPS").split(',') val fileSizes = System.getenv("SPARK_YARN_CACHE_FILES_FILE_SIZES").split(',') val distFiles = System.getenv("SPARK_YARN_CACHE_FILES").split(',') + val visibilities = System.getenv("SPARK_YARN_CACHE_FILES_VISIBILITIES").split(',') for( i <- 0 to distFiles.length - 1) { setupDistributedCache(distFiles(i), LocalResourceType.FILE, localResources, timeStamps(i), - fileSizes(i)) + fileSizes(i), visibilities(i)) } } @@ -203,9 +174,10 @@ class WorkerRunnable(container: Container, conf: Configuration, masterAddress: S val timeStamps = 
System.getenv("SPARK_YARN_CACHE_ARCHIVES_TIME_STAMPS").split(',') val fileSizes = System.getenv("SPARK_YARN_CACHE_ARCHIVES_FILE_SIZES").split(',') val distArchives = System.getenv("SPARK_YARN_CACHE_ARCHIVES").split(',') + val visibilities = System.getenv("SPARK_YARN_CACHE_ARCHIVES_VISIBILITIES").split(',') for( i <- 0 to distArchives.length - 1) { setupDistributedCache(distArchives(i), LocalResourceType.ARCHIVE, localResources, - timeStamps(i), fileSizes(i)) + timeStamps(i), fileSizes(i), visibilities(i)) } } diff --git a/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientDistributedCacheManagerSuite.scala b/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientDistributedCacheManagerSuite.scala new file mode 100644 index 0000000000..c0a2af0c6f --- /dev/null +++ b/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientDistributedCacheManagerSuite.scala @@ -0,0 +1,220 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.deploy.yarn + +import java.net.URI; + +import org.scalatest.FunSuite +import org.scalatest.mock.MockitoSugar +import org.mockito.Mockito.when + +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.FileStatus +import org.apache.hadoop.fs.FileSystem +import org.apache.hadoop.fs.Path +import org.apache.hadoop.fs.permission.FsAction +import org.apache.hadoop.yarn.api.records.LocalResource +import org.apache.hadoop.yarn.api.records.LocalResourceVisibility +import org.apache.hadoop.yarn.api.records.LocalResourceType +import org.apache.hadoop.yarn.util.{Records, ConverterUtils} + +import scala.collection.mutable.HashMap +import scala.collection.mutable.Map + + +class ClientDistributedCacheManagerSuite extends FunSuite with MockitoSugar { + + class MockClientDistributedCacheManager extends ClientDistributedCacheManager { + override def getVisibility(conf: Configuration, uri: URI, statCache: Map[URI, FileStatus]): + LocalResourceVisibility = { + return LocalResourceVisibility.PRIVATE + } + } + + test("test getFileStatus empty") { + val distMgr = new ClientDistributedCacheManager() + val fs = mock[FileSystem] + val uri = new URI("/tmp/testing") + when(fs.getFileStatus(new Path(uri))).thenReturn(new FileStatus()) + val statCache: Map[URI, FileStatus] = HashMap[URI, FileStatus]() + val stat = distMgr.getFileStatus(fs, uri, statCache) + assert(stat.getPath() === null) + } + + test("test getFileStatus cached") { + val distMgr = new ClientDistributedCacheManager() + val fs = mock[FileSystem] + val uri = new URI("/tmp/testing") + val realFileStatus = new FileStatus(10, false, 1, 1024, 10, 10, null, "testOwner", + null, new Path("/tmp/testing")) + when(fs.getFileStatus(new Path(uri))).thenReturn(new FileStatus()) + val statCache: Map[URI, FileStatus] = HashMap[URI, FileStatus](uri -> realFileStatus) + val stat = distMgr.getFileStatus(fs, uri, statCache) + assert(stat.getPath().toString() === "/tmp/testing") + } + + test("test 
addResource") { + val distMgr = new MockClientDistributedCacheManager() + val fs = mock[FileSystem] + val conf = new Configuration() + val destPath = new Path("file:///foo.invalid.com:8080/tmp/testing") + val localResources = HashMap[String, LocalResource]() + val statCache: Map[URI, FileStatus] = HashMap[URI, FileStatus]() + when(fs.getFileStatus(destPath)).thenReturn(new FileStatus()) + + distMgr.addResource(fs, conf, destPath, localResources, LocalResourceType.FILE, "link", + statCache, false) + val resource = localResources("link") + assert(resource.getVisibility() === LocalResourceVisibility.PRIVATE) + assert(ConverterUtils.getPathFromYarnURL(resource.getResource()) === destPath) + assert(resource.getTimestamp() === 0) + assert(resource.getSize() === 0) + assert(resource.getType() === LocalResourceType.FILE) + + val env = new HashMap[String, String]() + distMgr.setDistFilesEnv(env) + assert(env("SPARK_YARN_CACHE_FILES") === "file:/foo.invalid.com:8080/tmp/testing#link") + assert(env("SPARK_YARN_CACHE_FILES_TIME_STAMPS") === "0") + assert(env("SPARK_YARN_CACHE_FILES_FILE_SIZES") === "0") + assert(env("SPARK_YARN_CACHE_FILES_VISIBILITIES") === LocalResourceVisibility.PRIVATE.name()) + + distMgr.setDistArchivesEnv(env) + assert(env.get("SPARK_YARN_CACHE_ARCHIVES") === None) + assert(env.get("SPARK_YARN_CACHE_ARCHIVES_TIME_STAMPS") === None) + assert(env.get("SPARK_YARN_CACHE_ARCHIVES_FILE_SIZES") === None) + assert(env.get("SPARK_YARN_CACHE_ARCHIVES_VISIBILITIES") === None) + + //add another one and verify both there and order correct + val realFileStatus = new FileStatus(20, false, 1, 1024, 10, 30, null, "testOwner", + null, new Path("/tmp/testing2")) + val destPath2 = new Path("file:///foo.invalid.com:8080/tmp/testing2") + when(fs.getFileStatus(destPath2)).thenReturn(realFileStatus) + distMgr.addResource(fs, conf, destPath2, localResources, LocalResourceType.FILE, "link2", + statCache, false) + val resource2 = localResources("link2") + 
assert(resource2.getVisibility() === LocalResourceVisibility.PRIVATE) + assert(ConverterUtils.getPathFromYarnURL(resource2.getResource()) === destPath2) + assert(resource2.getTimestamp() === 10) + assert(resource2.getSize() === 20) + assert(resource2.getType() === LocalResourceType.FILE) + + val env2 = new HashMap[String, String]() + distMgr.setDistFilesEnv(env2) + val timestamps = env2("SPARK_YARN_CACHE_FILES_TIME_STAMPS").split(',') + val files = env2("SPARK_YARN_CACHE_FILES").split(',') + val sizes = env2("SPARK_YARN_CACHE_FILES_FILE_SIZES").split(',') + val visibilities = env2("SPARK_YARN_CACHE_FILES_VISIBILITIES") .split(',') + assert(files(0) === "file:/foo.invalid.com:8080/tmp/testing#link") + assert(timestamps(0) === "0") + assert(sizes(0) === "0") + assert(visibilities(0) === LocalResourceVisibility.PRIVATE.name()) + + assert(files(1) === "file:/foo.invalid.com:8080/tmp/testing2#link2") + assert(timestamps(1) === "10") + assert(sizes(1) === "20") + assert(visibilities(1) === LocalResourceVisibility.PRIVATE.name()) + } + + test("test addResource link null") { + val distMgr = new MockClientDistributedCacheManager() + val fs = mock[FileSystem] + val conf = new Configuration() + val destPath = new Path("file:///foo.invalid.com:8080/tmp/testing") + val localResources = HashMap[String, LocalResource]() + val statCache: Map[URI, FileStatus] = HashMap[URI, FileStatus]() + when(fs.getFileStatus(destPath)).thenReturn(new FileStatus()) + + intercept[Exception] { + distMgr.addResource(fs, conf, destPath, localResources, LocalResourceType.FILE, null, + statCache, false) + } + assert(localResources.get("link") === None) + assert(localResources.size === 0) + } + + test("test addResource appmaster only") { + val distMgr = new MockClientDistributedCacheManager() + val fs = mock[FileSystem] + val conf = new Configuration() + val destPath = new Path("file:///foo.invalid.com:8080/tmp/testing") + val localResources = HashMap[String, LocalResource]() + val statCache: Map[URI, 
FileStatus] = HashMap[URI, FileStatus]() + val realFileStatus = new FileStatus(20, false, 1, 1024, 10, 30, null, "testOwner", + null, new Path("/tmp/testing")) + when(fs.getFileStatus(destPath)).thenReturn(realFileStatus) + + distMgr.addResource(fs, conf, destPath, localResources, LocalResourceType.ARCHIVE, "link", + statCache, true) + val resource = localResources("link") + assert(resource.getVisibility() === LocalResourceVisibility.PRIVATE) + assert(ConverterUtils.getPathFromYarnURL(resource.getResource()) === destPath) + assert(resource.getTimestamp() === 10) + assert(resource.getSize() === 20) + assert(resource.getType() === LocalResourceType.ARCHIVE) + + val env = new HashMap[String, String]() + distMgr.setDistFilesEnv(env) + assert(env.get("SPARK_YARN_CACHE_FILES") === None) + assert(env.get("SPARK_YARN_CACHE_FILES_TIME_STAMPS") === None) + assert(env.get("SPARK_YARN_CACHE_FILES_FILE_SIZES") === None) + assert(env.get("SPARK_YARN_CACHE_FILES_VISIBILITIES") === None) + + distMgr.setDistArchivesEnv(env) + assert(env.get("SPARK_YARN_CACHE_ARCHIVES") === None) + assert(env.get("SPARK_YARN_CACHE_ARCHIVES_TIME_STAMPS") === None) + assert(env.get("SPARK_YARN_CACHE_ARCHIVES_FILE_SIZES") === None) + assert(env.get("SPARK_YARN_CACHE_ARCHIVES_VISIBILITIES") === None) + } + + test("test addResource archive") { + val distMgr = new MockClientDistributedCacheManager() + val fs = mock[FileSystem] + val conf = new Configuration() + val destPath = new Path("file:///foo.invalid.com:8080/tmp/testing") + val localResources = HashMap[String, LocalResource]() + val statCache: Map[URI, FileStatus] = HashMap[URI, FileStatus]() + val realFileStatus = new FileStatus(20, false, 1, 1024, 10, 30, null, "testOwner", + null, new Path("/tmp/testing")) + when(fs.getFileStatus(destPath)).thenReturn(realFileStatus) + + distMgr.addResource(fs, conf, destPath, localResources, LocalResourceType.ARCHIVE, "link", + statCache, false) + val resource = localResources("link") + 
assert(resource.getVisibility() === LocalResourceVisibility.PRIVATE) + assert(ConverterUtils.getPathFromYarnURL(resource.getResource()) === destPath) + assert(resource.getTimestamp() === 10) + assert(resource.getSize() === 20) + assert(resource.getType() === LocalResourceType.ARCHIVE) + + val env = new HashMap[String, String]() + + distMgr.setDistArchivesEnv(env) + assert(env("SPARK_YARN_CACHE_ARCHIVES") === "file:/foo.invalid.com:8080/tmp/testing#link") + assert(env("SPARK_YARN_CACHE_ARCHIVES_TIME_STAMPS") === "10") + assert(env("SPARK_YARN_CACHE_ARCHIVES_FILE_SIZES") === "20") + assert(env("SPARK_YARN_CACHE_ARCHIVES_VISIBILITIES") === LocalResourceVisibility.PRIVATE.name()) + + distMgr.setDistFilesEnv(env) + assert(env.get("SPARK_YARN_CACHE_FILES") === None) + assert(env.get("SPARK_YARN_CACHE_FILES_TIME_STAMPS") === None) + assert(env.get("SPARK_YARN_CACHE_FILES_FILE_SIZES") === None) + assert(env.get("SPARK_YARN_CACHE_FILES_VISIBILITIES") === None) + } + + +} |