aboutsummaryrefslogtreecommitdiff
path: root/yarn/common
diff options
context:
space:
mode:
authorRaymond Liu <raymond.liu@intel.com>2013-12-04 13:20:27 +0800
committerRaymond Liu <raymond.liu@intel.com>2014-01-03 12:12:37 +0800
commit3dc379ce5aa51cc9c41f590d79c350b5dea33fc3 (patch)
tree5a99812d5e89006a3f1d4106e6eca4eb51e81433 /yarn/common
parent498a5f0a1c6e82a33c2ad8c48b68bbdb8da57a95 (diff)
downloadspark-3dc379ce5aa51cc9c41f590d79c350b5dea33fc3.tar.gz
spark-3dc379ce5aa51cc9c41f590d79c350b5dea33fc3.tar.bz2
spark-3dc379ce5aa51cc9c41f590d79c350b5dea33fc3.zip
Reorganize yarn related codes into sub projects to remove duplicate files.
Diffstat (limited to 'yarn/common')
-rw-r--r--yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMasterArguments.scala94
-rw-r--r--yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala150
-rw-r--r--yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientDistributedCacheManager.scala228
-rw-r--r--yarn/common/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala43
-rw-r--r--yarn/common/src/test/scala/org/apache/spark/deploy/yarn/ClientDistributedCacheManagerSuite.scala220
5 files changed, 735 insertions, 0 deletions
diff --git a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMasterArguments.scala b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMasterArguments.scala
new file mode 100644
index 0000000000..f76a5ddd39
--- /dev/null
+++ b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMasterArguments.scala
@@ -0,0 +1,94 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.deploy.yarn
+
+import org.apache.spark.util.IntParam
+import collection.mutable.ArrayBuffer
+
+class ApplicationMasterArguments(val args: Array[String]) {
+ var userJar: String = null
+ var userClass: String = null
+ var userArgs: Seq[String] = Seq[String]()
+ var workerMemory = 1024
+ var workerCores = 1
+ var numWorkers = 2
+
+ parseArgs(args.toList)
+
+ private def parseArgs(inputArgs: List[String]): Unit = {
+ val userArgsBuffer = new ArrayBuffer[String]()
+
+ var args = inputArgs
+
+ while (! args.isEmpty) {
+
+ args match {
+ case ("--jar") :: value :: tail =>
+ userJar = value
+ args = tail
+
+ case ("--class") :: value :: tail =>
+ userClass = value
+ args = tail
+
+ case ("--args") :: value :: tail =>
+ userArgsBuffer += value
+ args = tail
+
+ case ("--num-workers") :: IntParam(value) :: tail =>
+ numWorkers = value
+ args = tail
+
+ case ("--worker-memory") :: IntParam(value) :: tail =>
+ workerMemory = value
+ args = tail
+
+ case ("--worker-cores") :: IntParam(value) :: tail =>
+ workerCores = value
+ args = tail
+
+ case Nil =>
+ if (userJar == null || userClass == null) {
+ printUsageAndExit(1)
+ }
+
+ case _ =>
+ printUsageAndExit(1, args)
+ }
+ }
+
+ userArgs = userArgsBuffer.readOnly
+ }
+
+ def printUsageAndExit(exitCode: Int, unknownParam: Any = null) {
+ if (unknownParam != null) {
+ System.err.println("Unknown/unsupported param " + unknownParam)
+ }
+ System.err.println(
+ "Usage: org.apache.spark.deploy.yarn.ApplicationMaster [options] \n" +
+ "Options:\n" +
+ " --jar JAR_PATH Path to your application's JAR file (required)\n" +
+ " --class CLASS_NAME Name of your application's main class (required)\n" +
+ " --args ARGS Arguments to be passed to your application's main class.\n" +
+ " Mutliple invocations are possible, each will be passed in order.\n" +
+ " --num-workers NUM Number of workers to start (Default: 2)\n" +
+ " --worker-cores NUM Number of cores for the workers (Default: 1)\n" +
+ " --worker-memory MEM Memory per Worker (e.g. 1000M, 2G) (Default: 1G)\n")
+ System.exit(exitCode)
+ }
+}
diff --git a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala
new file mode 100644
index 0000000000..7aac2328da
--- /dev/null
+++ b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala
@@ -0,0 +1,150 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.deploy.yarn
+
+import scala.collection.mutable.{ArrayBuffer, HashMap}
+
+import org.apache.spark.SparkConf
+import org.apache.spark.scheduler.{InputFormatInfo, SplitInfo}
+import org.apache.spark.util.IntParam
+import org.apache.spark.util.MemoryParam
+
+
+// TODO: Add code and support for ensuring that yarn resource 'tasks' are location aware !
+class ClientArguments(val args: Array[String]) {
+ var addJars: String = null
+ var files: String = null
+ var archives: String = null
+ var userJar: String = null
+ var userClass: String = null
+ var userArgs: Seq[String] = Seq[String]()
+ var workerMemory = 1024 // MB
+ var workerCores = 1
+ var numWorkers = 2
+ var amQueue = new SparkConf().get("QUEUE", "default")
+ var amMemory: Int = 512 // MB
+ var amClass: String = "org.apache.spark.deploy.yarn.ApplicationMaster"
+ var appName: String = "Spark"
+ // TODO
+ var inputFormatInfo: List[InputFormatInfo] = null
+ // TODO(harvey)
+ var priority = 0
+
+ parseArgs(args.toList)
+
+ private def parseArgs(inputArgs: List[String]): Unit = {
+ val userArgsBuffer: ArrayBuffer[String] = new ArrayBuffer[String]()
+ val inputFormatMap: HashMap[String, InputFormatInfo] = new HashMap[String, InputFormatInfo]()
+
+ var args = inputArgs
+
+ while (!args.isEmpty) {
+ args match {
+ case ("--jar") :: value :: tail =>
+ userJar = value
+ args = tail
+
+ case ("--class") :: value :: tail =>
+ userClass = value
+ args = tail
+
+ case ("--args") :: value :: tail =>
+ userArgsBuffer += value
+ args = tail
+
+ case ("--master-class") :: value :: tail =>
+ amClass = value
+ args = tail
+
+ case ("--master-memory") :: MemoryParam(value) :: tail =>
+ amMemory = value
+ args = tail
+
+ case ("--num-workers") :: IntParam(value) :: tail =>
+ numWorkers = value
+ args = tail
+
+ case ("--worker-memory") :: MemoryParam(value) :: tail =>
+ workerMemory = value
+ args = tail
+
+ case ("--worker-cores") :: IntParam(value) :: tail =>
+ workerCores = value
+ args = tail
+
+ case ("--queue") :: value :: tail =>
+ amQueue = value
+ args = tail
+
+ case ("--name") :: value :: tail =>
+ appName = value
+ args = tail
+
+ case ("--addJars") :: value :: tail =>
+ addJars = value
+ args = tail
+
+ case ("--files") :: value :: tail =>
+ files = value
+ args = tail
+
+ case ("--archives") :: value :: tail =>
+ archives = value
+ args = tail
+
+ case Nil =>
+ if (userJar == null || userClass == null) {
+ printUsageAndExit(1)
+ }
+
+ case _ =>
+ printUsageAndExit(1, args)
+ }
+ }
+
+ userArgs = userArgsBuffer.readOnly
+ inputFormatInfo = inputFormatMap.values.toList
+ }
+
+
+ def printUsageAndExit(exitCode: Int, unknownParam: Any = null) {
+ if (unknownParam != null) {
+ System.err.println("Unknown/unsupported param " + unknownParam)
+ }
+ System.err.println(
+ "Usage: org.apache.spark.deploy.yarn.Client [options] \n" +
+ "Options:\n" +
+ " --jar JAR_PATH Path to your application's JAR file (required)\n" +
+ " --class CLASS_NAME Name of your application's main class (required)\n" +
+ " --args ARGS Arguments to be passed to your application's main class.\n" +
+ " Mutliple invocations are possible, each will be passed in order.\n" +
+ " --num-workers NUM Number of workers to start (Default: 2)\n" +
+ " --worker-cores NUM Number of cores for the workers (Default: 1). This is unsused right now.\n" +
+ " --master-class CLASS_NAME Class Name for Master (Default: spark.deploy.yarn.ApplicationMaster)\n" +
+ " --master-memory MEM Memory for Master (e.g. 1000M, 2G) (Default: 512 Mb)\n" +
+ " --worker-memory MEM Memory per Worker (e.g. 1000M, 2G) (Default: 1G)\n" +
+ " --name NAME The name of your application (Default: Spark)\n" +
+ " --queue QUEUE The hadoop queue to use for allocation requests (Default: 'default')\n" +
+ " --addJars jars Comma separated list of local jars that want SparkContext.addJar to work with.\n" +
+ " --files files Comma separated list of files to be distributed with the job.\n" +
+ " --archives archives Comma separated list of archives to be distributed with the job."
+ )
+ System.exit(exitCode)
+ }
+
+}
diff --git a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientDistributedCacheManager.scala b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientDistributedCacheManager.scala
new file mode 100644
index 0000000000..5f159b073f
--- /dev/null
+++ b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientDistributedCacheManager.scala
@@ -0,0 +1,228 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.deploy.yarn
+
+import java.net.URI
+
+import org.apache.hadoop.conf.Configuration
+import org.apache.hadoop.fs.FileStatus
+import org.apache.hadoop.fs.FileSystem
+import org.apache.hadoop.fs.Path
+import org.apache.hadoop.fs.permission.FsAction
+import org.apache.hadoop.yarn.api.records.LocalResource
+import org.apache.hadoop.yarn.api.records.LocalResourceVisibility
+import org.apache.hadoop.yarn.api.records.LocalResourceType
+import org.apache.hadoop.yarn.util.{Records, ConverterUtils}
+
+import org.apache.spark.Logging
+
+import scala.collection.mutable.HashMap
+import scala.collection.mutable.LinkedHashMap
+import scala.collection.mutable.Map
+
+
+/** Client side methods to setup the Hadoop distributed cache */
+class ClientDistributedCacheManager() extends Logging {
+ private val distCacheFiles: Map[String, Tuple3[String, String, String]] =
+ LinkedHashMap[String, Tuple3[String, String, String]]()
+ private val distCacheArchives: Map[String, Tuple3[String, String, String]] =
+ LinkedHashMap[String, Tuple3[String, String, String]]()
+
+
+ /**
+ * Add a resource to the list of distributed cache resources. This list can
+ * be sent to the ApplicationMaster and possibly the workers so that it can
+ * be downloaded into the Hadoop distributed cache for use by this application.
+ * Adds the LocalResource to the localResources HashMap passed in and saves
+ * the stats of the resources to they can be sent to the workers and verified.
+ *
+ * @param fs FileSystem
+ * @param conf Configuration
+ * @param destPath path to the resource
+ * @param localResources localResource hashMap to insert the resource into
+ * @param resourceType LocalResourceType
+ * @param link link presented in the distributed cache to the destination
+ * @param statCache cache to store the file/directory stats
+ * @param appMasterOnly Whether to only add the resource to the app master
+ */
+ def addResource(
+ fs: FileSystem,
+ conf: Configuration,
+ destPath: Path,
+ localResources: HashMap[String, LocalResource],
+ resourceType: LocalResourceType,
+ link: String,
+ statCache: Map[URI, FileStatus],
+ appMasterOnly: Boolean = false) = {
+ val destStatus = fs.getFileStatus(destPath)
+ val amJarRsrc = Records.newRecord(classOf[LocalResource]).asInstanceOf[LocalResource]
+ amJarRsrc.setType(resourceType)
+ val visibility = getVisibility(conf, destPath.toUri(), statCache)
+ amJarRsrc.setVisibility(visibility)
+ amJarRsrc.setResource(ConverterUtils.getYarnUrlFromPath(destPath))
+ amJarRsrc.setTimestamp(destStatus.getModificationTime())
+ amJarRsrc.setSize(destStatus.getLen())
+ if (link == null || link.isEmpty()) throw new Exception("You must specify a valid link name")
+ localResources(link) = amJarRsrc
+
+ if (appMasterOnly == false) {
+ val uri = destPath.toUri()
+ val pathURI = new URI(uri.getScheme(), uri.getAuthority(), uri.getPath(), null, link)
+ if (resourceType == LocalResourceType.FILE) {
+ distCacheFiles(pathURI.toString()) = (destStatus.getLen().toString(),
+ destStatus.getModificationTime().toString(), visibility.name())
+ } else {
+ distCacheArchives(pathURI.toString()) = (destStatus.getLen().toString(),
+ destStatus.getModificationTime().toString(), visibility.name())
+ }
+ }
+ }
+
+ /**
+ * Adds the necessary cache file env variables to the env passed in
+ * @param env
+ */
+ def setDistFilesEnv(env: Map[String, String]) = {
+ val (keys, tupleValues) = distCacheFiles.unzip
+ val (sizes, timeStamps, visibilities) = tupleValues.unzip3
+
+ if (keys.size > 0) {
+ env("SPARK_YARN_CACHE_FILES") = keys.reduceLeft[String] { (acc,n) => acc + "," + n }
+ env("SPARK_YARN_CACHE_FILES_TIME_STAMPS") =
+ timeStamps.reduceLeft[String] { (acc,n) => acc + "," + n }
+ env("SPARK_YARN_CACHE_FILES_FILE_SIZES") =
+ sizes.reduceLeft[String] { (acc,n) => acc + "," + n }
+ env("SPARK_YARN_CACHE_FILES_VISIBILITIES") =
+ visibilities.reduceLeft[String] { (acc,n) => acc + "," + n }
+ }
+ }
+
+ /**
+ * Adds the necessary cache archive env variables to the env passed in
+ * @param env
+ */
+ def setDistArchivesEnv(env: Map[String, String]) = {
+ val (keys, tupleValues) = distCacheArchives.unzip
+ val (sizes, timeStamps, visibilities) = tupleValues.unzip3
+
+ if (keys.size > 0) {
+ env("SPARK_YARN_CACHE_ARCHIVES") = keys.reduceLeft[String] { (acc,n) => acc + "," + n }
+ env("SPARK_YARN_CACHE_ARCHIVES_TIME_STAMPS") =
+ timeStamps.reduceLeft[String] { (acc,n) => acc + "," + n }
+ env("SPARK_YARN_CACHE_ARCHIVES_FILE_SIZES") =
+ sizes.reduceLeft[String] { (acc,n) => acc + "," + n }
+ env("SPARK_YARN_CACHE_ARCHIVES_VISIBILITIES") =
+ visibilities.reduceLeft[String] { (acc,n) => acc + "," + n }
+ }
+ }
+
+ /**
+ * Returns the local resource visibility depending on the cache file permissions
+ * @param conf
+ * @param uri
+ * @param statCache
+ * @return LocalResourceVisibility
+ */
+ def getVisibility(conf: Configuration, uri: URI, statCache: Map[URI, FileStatus]):
+ LocalResourceVisibility = {
+ if (isPublic(conf, uri, statCache)) {
+ return LocalResourceVisibility.PUBLIC
+ }
+ return LocalResourceVisibility.PRIVATE
+ }
+
+ /**
+ * Returns a boolean to denote whether a cache file is visible to all(public)
+ * or not
+ * @param conf
+ * @param uri
+ * @param statCache
+ * @return true if the path in the uri is visible to all, false otherwise
+ */
+ def isPublic(conf: Configuration, uri: URI, statCache: Map[URI, FileStatus]): Boolean = {
+ val fs = FileSystem.get(uri, conf)
+ val current = new Path(uri.getPath())
+ //the leaf level file should be readable by others
+ if (!checkPermissionOfOther(fs, current, FsAction.READ, statCache)) {
+ return false
+ }
+ return ancestorsHaveExecutePermissions(fs, current.getParent(), statCache)
+ }
+
+ /**
+ * Returns true if all ancestors of the specified path have the 'execute'
+ * permission set for all users (i.e. that other users can traverse
+ * the directory heirarchy to the given path)
+ * @param fs
+ * @param path
+ * @param statCache
+ * @return true if all ancestors have the 'execute' permission set for all users
+ */
+ def ancestorsHaveExecutePermissions(fs: FileSystem, path: Path,
+ statCache: Map[URI, FileStatus]): Boolean = {
+ var current = path
+ while (current != null) {
+ //the subdirs in the path should have execute permissions for others
+ if (!checkPermissionOfOther(fs, current, FsAction.EXECUTE, statCache)) {
+ return false
+ }
+ current = current.getParent()
+ }
+ return true
+ }
+
+ /**
+ * Checks for a given path whether the Other permissions on it
+ * imply the permission in the passed FsAction
+ * @param fs
+ * @param path
+ * @param action
+ * @param statCache
+ * @return true if the path in the uri is visible to all, false otherwise
+ */
+ def checkPermissionOfOther(fs: FileSystem, path: Path,
+ action: FsAction, statCache: Map[URI, FileStatus]): Boolean = {
+ val status = getFileStatus(fs, path.toUri(), statCache)
+ val perms = status.getPermission()
+ val otherAction = perms.getOtherAction()
+ if (otherAction.implies(action)) {
+ return true
+ }
+ return false
+ }
+
+ /**
+ * Checks to see if the given uri exists in the cache, if it does it
+ * returns the existing FileStatus, otherwise it stats the uri, stores
+ * it in the cache, and returns the FileStatus.
+ * @param fs
+ * @param uri
+ * @param statCache
+ * @return FileStatus
+ */
+ def getFileStatus(fs: FileSystem, uri: URI, statCache: Map[URI, FileStatus]): FileStatus = {
+ val stat = statCache.get(uri) match {
+ case Some(existstat) => existstat
+ case None =>
+ val newStat = fs.getFileStatus(new Path(uri))
+ statCache.put(uri, newStat)
+ newStat
+ }
+ return stat
+ }
+}
diff --git a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala
new file mode 100644
index 0000000000..2ba2366ead
--- /dev/null
+++ b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala
@@ -0,0 +1,43 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.deploy.yarn
+
+import org.apache.spark.deploy.SparkHadoopUtil
+import org.apache.hadoop.mapred.JobConf
+import org.apache.hadoop.security.UserGroupInformation
+import org.apache.hadoop.yarn.conf.YarnConfiguration
+import org.apache.hadoop.conf.Configuration
+
+/**
+ * Contains util methods to interact with Hadoop from spark.
+ */
+class YarnSparkHadoopUtil extends SparkHadoopUtil {
+
+ // Note that all params which start with SPARK are propagated all the way through, so if in yarn mode, this MUST be set to true.
+ override def isYarnMode(): Boolean = { true }
+
+ // Return an appropriate (subclass) of Configuration. Creating config can initializes some hadoop subsystems
+ // Always create a new config, dont reuse yarnConf.
+ override def newConfiguration(): Configuration = new YarnConfiguration(new Configuration())
+
+ // add any user credentials to the job conf which are necessary for running on a secure Hadoop cluster
+ override def addCredentials(conf: JobConf) {
+ val jobCreds = conf.getCredentials()
+ jobCreds.mergeAll(UserGroupInformation.getCurrentUser().getCredentials())
+ }
+}
diff --git a/yarn/common/src/test/scala/org/apache/spark/deploy/yarn/ClientDistributedCacheManagerSuite.scala b/yarn/common/src/test/scala/org/apache/spark/deploy/yarn/ClientDistributedCacheManagerSuite.scala
new file mode 100644
index 0000000000..2941356bc5
--- /dev/null
+++ b/yarn/common/src/test/scala/org/apache/spark/deploy/yarn/ClientDistributedCacheManagerSuite.scala
@@ -0,0 +1,220 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.deploy.yarn
+
+import java.net.URI
+
+import org.scalatest.FunSuite
+import org.scalatest.mock.MockitoSugar
+import org.mockito.Mockito.when
+
+import org.apache.hadoop.conf.Configuration
+import org.apache.hadoop.fs.FileStatus
+import org.apache.hadoop.fs.FileSystem
+import org.apache.hadoop.fs.Path
+import org.apache.hadoop.fs.permission.FsAction
+import org.apache.hadoop.yarn.api.records.LocalResource
+import org.apache.hadoop.yarn.api.records.LocalResourceVisibility
+import org.apache.hadoop.yarn.api.records.LocalResourceType
+import org.apache.hadoop.yarn.util.{Records, ConverterUtils}
+
+import scala.collection.mutable.HashMap
+import scala.collection.mutable.Map
+
+
+class ClientDistributedCacheManagerSuite extends FunSuite with MockitoSugar {
+
+ class MockClientDistributedCacheManager extends ClientDistributedCacheManager {
+ override def getVisibility(conf: Configuration, uri: URI, statCache: Map[URI, FileStatus]):
+ LocalResourceVisibility = {
+ return LocalResourceVisibility.PRIVATE
+ }
+ }
+
+ test("test getFileStatus empty") {
+ val distMgr = new ClientDistributedCacheManager()
+ val fs = mock[FileSystem]
+ val uri = new URI("/tmp/testing")
+ when(fs.getFileStatus(new Path(uri))).thenReturn(new FileStatus())
+ val statCache: Map[URI, FileStatus] = HashMap[URI, FileStatus]()
+ val stat = distMgr.getFileStatus(fs, uri, statCache)
+ assert(stat.getPath() === null)
+ }
+
+ test("test getFileStatus cached") {
+ val distMgr = new ClientDistributedCacheManager()
+ val fs = mock[FileSystem]
+ val uri = new URI("/tmp/testing")
+ val realFileStatus = new FileStatus(10, false, 1, 1024, 10, 10, null, "testOwner",
+ null, new Path("/tmp/testing"))
+ when(fs.getFileStatus(new Path(uri))).thenReturn(new FileStatus())
+ val statCache: Map[URI, FileStatus] = HashMap[URI, FileStatus](uri -> realFileStatus)
+ val stat = distMgr.getFileStatus(fs, uri, statCache)
+ assert(stat.getPath().toString() === "/tmp/testing")
+ }
+
+ test("test addResource") {
+ val distMgr = new MockClientDistributedCacheManager()
+ val fs = mock[FileSystem]
+ val conf = new Configuration()
+ val destPath = new Path("file:///foo.invalid.com:8080/tmp/testing")
+ val localResources = HashMap[String, LocalResource]()
+ val statCache: Map[URI, FileStatus] = HashMap[URI, FileStatus]()
+ when(fs.getFileStatus(destPath)).thenReturn(new FileStatus())
+
+ distMgr.addResource(fs, conf, destPath, localResources, LocalResourceType.FILE, "link",
+ statCache, false)
+ val resource = localResources("link")
+ assert(resource.getVisibility() === LocalResourceVisibility.PRIVATE)
+ assert(ConverterUtils.getPathFromYarnURL(resource.getResource()) === destPath)
+ assert(resource.getTimestamp() === 0)
+ assert(resource.getSize() === 0)
+ assert(resource.getType() === LocalResourceType.FILE)
+
+ val env = new HashMap[String, String]()
+ distMgr.setDistFilesEnv(env)
+ assert(env("SPARK_YARN_CACHE_FILES") === "file:/foo.invalid.com:8080/tmp/testing#link")
+ assert(env("SPARK_YARN_CACHE_FILES_TIME_STAMPS") === "0")
+ assert(env("SPARK_YARN_CACHE_FILES_FILE_SIZES") === "0")
+ assert(env("SPARK_YARN_CACHE_FILES_VISIBILITIES") === LocalResourceVisibility.PRIVATE.name())
+
+ distMgr.setDistArchivesEnv(env)
+ assert(env.get("SPARK_YARN_CACHE_ARCHIVES") === None)
+ assert(env.get("SPARK_YARN_CACHE_ARCHIVES_TIME_STAMPS") === None)
+ assert(env.get("SPARK_YARN_CACHE_ARCHIVES_FILE_SIZES") === None)
+ assert(env.get("SPARK_YARN_CACHE_ARCHIVES_VISIBILITIES") === None)
+
+ //add another one and verify both there and order correct
+ val realFileStatus = new FileStatus(20, false, 1, 1024, 10, 30, null, "testOwner",
+ null, new Path("/tmp/testing2"))
+ val destPath2 = new Path("file:///foo.invalid.com:8080/tmp/testing2")
+ when(fs.getFileStatus(destPath2)).thenReturn(realFileStatus)
+ distMgr.addResource(fs, conf, destPath2, localResources, LocalResourceType.FILE, "link2",
+ statCache, false)
+ val resource2 = localResources("link2")
+ assert(resource2.getVisibility() === LocalResourceVisibility.PRIVATE)
+ assert(ConverterUtils.getPathFromYarnURL(resource2.getResource()) === destPath2)
+ assert(resource2.getTimestamp() === 10)
+ assert(resource2.getSize() === 20)
+ assert(resource2.getType() === LocalResourceType.FILE)
+
+ val env2 = new HashMap[String, String]()
+ distMgr.setDistFilesEnv(env2)
+ val timestamps = env2("SPARK_YARN_CACHE_FILES_TIME_STAMPS").split(',')
+ val files = env2("SPARK_YARN_CACHE_FILES").split(',')
+ val sizes = env2("SPARK_YARN_CACHE_FILES_FILE_SIZES").split(',')
+ val visibilities = env2("SPARK_YARN_CACHE_FILES_VISIBILITIES") .split(',')
+ assert(files(0) === "file:/foo.invalid.com:8080/tmp/testing#link")
+ assert(timestamps(0) === "0")
+ assert(sizes(0) === "0")
+ assert(visibilities(0) === LocalResourceVisibility.PRIVATE.name())
+
+ assert(files(1) === "file:/foo.invalid.com:8080/tmp/testing2#link2")
+ assert(timestamps(1) === "10")
+ assert(sizes(1) === "20")
+ assert(visibilities(1) === LocalResourceVisibility.PRIVATE.name())
+ }
+
+ test("test addResource link null") {
+ val distMgr = new MockClientDistributedCacheManager()
+ val fs = mock[FileSystem]
+ val conf = new Configuration()
+ val destPath = new Path("file:///foo.invalid.com:8080/tmp/testing")
+ val localResources = HashMap[String, LocalResource]()
+ val statCache: Map[URI, FileStatus] = HashMap[URI, FileStatus]()
+ when(fs.getFileStatus(destPath)).thenReturn(new FileStatus())
+
+ intercept[Exception] {
+ distMgr.addResource(fs, conf, destPath, localResources, LocalResourceType.FILE, null,
+ statCache, false)
+ }
+ assert(localResources.get("link") === None)
+ assert(localResources.size === 0)
+ }
+
+ test("test addResource appmaster only") {
+ val distMgr = new MockClientDistributedCacheManager()
+ val fs = mock[FileSystem]
+ val conf = new Configuration()
+ val destPath = new Path("file:///foo.invalid.com:8080/tmp/testing")
+ val localResources = HashMap[String, LocalResource]()
+ val statCache: Map[URI, FileStatus] = HashMap[URI, FileStatus]()
+ val realFileStatus = new FileStatus(20, false, 1, 1024, 10, 30, null, "testOwner",
+ null, new Path("/tmp/testing"))
+ when(fs.getFileStatus(destPath)).thenReturn(realFileStatus)
+
+ distMgr.addResource(fs, conf, destPath, localResources, LocalResourceType.ARCHIVE, "link",
+ statCache, true)
+ val resource = localResources("link")
+ assert(resource.getVisibility() === LocalResourceVisibility.PRIVATE)
+ assert(ConverterUtils.getPathFromYarnURL(resource.getResource()) === destPath)
+ assert(resource.getTimestamp() === 10)
+ assert(resource.getSize() === 20)
+ assert(resource.getType() === LocalResourceType.ARCHIVE)
+
+ val env = new HashMap[String, String]()
+ distMgr.setDistFilesEnv(env)
+ assert(env.get("SPARK_YARN_CACHE_FILES") === None)
+ assert(env.get("SPARK_YARN_CACHE_FILES_TIME_STAMPS") === None)
+ assert(env.get("SPARK_YARN_CACHE_FILES_FILE_SIZES") === None)
+ assert(env.get("SPARK_YARN_CACHE_FILES_VISIBILITIES") === None)
+
+ distMgr.setDistArchivesEnv(env)
+ assert(env.get("SPARK_YARN_CACHE_ARCHIVES") === None)
+ assert(env.get("SPARK_YARN_CACHE_ARCHIVES_TIME_STAMPS") === None)
+ assert(env.get("SPARK_YARN_CACHE_ARCHIVES_FILE_SIZES") === None)
+ assert(env.get("SPARK_YARN_CACHE_ARCHIVES_VISIBILITIES") === None)
+ }
+
+ test("test addResource archive") {
+ val distMgr = new MockClientDistributedCacheManager()
+ val fs = mock[FileSystem]
+ val conf = new Configuration()
+ val destPath = new Path("file:///foo.invalid.com:8080/tmp/testing")
+ val localResources = HashMap[String, LocalResource]()
+ val statCache: Map[URI, FileStatus] = HashMap[URI, FileStatus]()
+ val realFileStatus = new FileStatus(20, false, 1, 1024, 10, 30, null, "testOwner",
+ null, new Path("/tmp/testing"))
+ when(fs.getFileStatus(destPath)).thenReturn(realFileStatus)
+
+ distMgr.addResource(fs, conf, destPath, localResources, LocalResourceType.ARCHIVE, "link",
+ statCache, false)
+ val resource = localResources("link")
+ assert(resource.getVisibility() === LocalResourceVisibility.PRIVATE)
+ assert(ConverterUtils.getPathFromYarnURL(resource.getResource()) === destPath)
+ assert(resource.getTimestamp() === 10)
+ assert(resource.getSize() === 20)
+ assert(resource.getType() === LocalResourceType.ARCHIVE)
+
+ val env = new HashMap[String, String]()
+
+ distMgr.setDistArchivesEnv(env)
+ assert(env("SPARK_YARN_CACHE_ARCHIVES") === "file:/foo.invalid.com:8080/tmp/testing#link")
+ assert(env("SPARK_YARN_CACHE_ARCHIVES_TIME_STAMPS") === "10")
+ assert(env("SPARK_YARN_CACHE_ARCHIVES_FILE_SIZES") === "20")
+ assert(env("SPARK_YARN_CACHE_ARCHIVES_VISIBILITIES") === LocalResourceVisibility.PRIVATE.name())
+
+ distMgr.setDistFilesEnv(env)
+ assert(env.get("SPARK_YARN_CACHE_FILES") === None)
+ assert(env.get("SPARK_YARN_CACHE_FILES_TIME_STAMPS") === None)
+ assert(env.get("SPARK_YARN_CACHE_FILES_FILE_SIZES") === None)
+ assert(env.get("SPARK_YARN_CACHE_FILES_VISIBILITIES") === None)
+ }
+
+
+}