Diffstat (limited to 'yarn/common')
-rw-r--r--  yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMasterArguments.scala |  6
-rw-r--r--  yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientBase.scala                 |  9
-rw-r--r--  yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnableUtil.scala       |  4
-rw-r--r--  yarn/common/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala        | 25
-rw-r--r--  yarn/common/src/test/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtilSuite.scala   | 64
5 files changed, 99 insertions(+), 9 deletions(-)
diff --git a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMasterArguments.scala b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMasterArguments.scala
index 4c383ab574..424b0fb093 100644
--- a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMasterArguments.scala
+++ b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMasterArguments.scala
@@ -29,7 +29,7 @@ class ApplicationMasterArguments(val args: Array[String]) {
var numExecutors = ApplicationMasterArguments.DEFAULT_NUMBER_EXECUTORS
parseArgs(args.toList)
-
+
private def parseArgs(inputArgs: List[String]): Unit = {
val userArgsBuffer = new ArrayBuffer[String]()
@@ -47,7 +47,7 @@ class ApplicationMasterArguments(val args: Array[String]) {
userClass = value
args = tail
- case ("--args") :: value :: tail =>
+ case ("--args" | "--arg") :: value :: tail =>
userArgsBuffer += value
args = tail
@@ -75,7 +75,7 @@ class ApplicationMasterArguments(val args: Array[String]) {
userArgs = userArgsBuffer.readOnly
}
-
+
def printUsageAndExit(exitCode: Int, unknownParam: Any = null) {
if (unknownParam != null) {
System.err.println("Unknown/unsupported param " + unknownParam)
diff --git a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientBase.scala b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientBase.scala
index 1da0a1b675..3897b3a373 100644
--- a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientBase.scala
+++ b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientBase.scala
@@ -300,11 +300,11 @@ trait ClientBase extends Logging {
}
def userArgsToString(clientArgs: ClientArguments): String = {
- val prefix = " --args "
+ val prefix = " --arg "
val args = clientArgs.userArgs
val retval = new StringBuilder()
for (arg <- args) {
- retval.append(prefix).append(" '").append(arg).append("' ")
+ retval.append(prefix).append(" ").append(YarnSparkHadoopUtil.escapeForShell(arg))
}
retval.toString
}
@@ -386,7 +386,7 @@ trait ClientBase extends Logging {
// TODO: it might be nicer to pass these as an internal environment variable rather than
// as Java options, due to complications with string parsing of nested quotes.
for ((k, v) <- sparkConf.getAll) {
- javaOpts += "-D" + k + "=" + "\\\"" + v + "\\\""
+ javaOpts += YarnSparkHadoopUtil.escapeForShell(s"-D$k=$v")
}
if (args.amClass == classOf[ApplicationMaster].getName) {
@@ -400,7 +400,8 @@ trait ClientBase extends Logging {
// Command for the ApplicationMaster
val commands = Seq(Environment.JAVA_HOME.$() + "/bin/java", "-server") ++
javaOpts ++
- Seq(args.amClass, "--class", args.userClass, "--jar ", args.userJar,
+ Seq(args.amClass, "--class", YarnSparkHadoopUtil.escapeForShell(args.userClass),
+ "--jar ", YarnSparkHadoopUtil.escapeForShell(args.userJar),
userArgsToString(args),
"--executor-memory", args.executorMemory.toString,
"--executor-cores", args.executorCores.toString,
diff --git a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnableUtil.scala b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnableUtil.scala
index 71a9e42846..312d82a649 100644
--- a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnableUtil.scala
+++ b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnableUtil.scala
@@ -68,10 +68,10 @@ trait ExecutorRunnableUtil extends Logging {
// authentication settings.
sparkConf.getAll.
filter { case (k, v) => k.startsWith("spark.auth") || k.startsWith("spark.akka") }.
- foreach { case (k, v) => javaOpts += "-D" + k + "=" + "\\\"" + v + "\\\"" }
+ foreach { case (k, v) => javaOpts += YarnSparkHadoopUtil.escapeForShell(s"-D$k=$v") }
sparkConf.getAkkaConf.
- foreach { case (k, v) => javaOpts += "-D" + k + "=" + "\\\"" + v + "\\\"" }
+ foreach { case (k, v) => javaOpts += YarnSparkHadoopUtil.escapeForShell(s"-D$k=$v") }
// Commenting it out for now - so that people can refer to the properties if required. Remove
// it once cpuset version is pushed out.
diff --git a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala
index e98308cdbd..10aef5eb24 100644
--- a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala
+++ b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala
@@ -148,4 +148,29 @@ object YarnSparkHadoopUtil {
}
}
+ /**
+ * Escapes a string for inclusion in a command line executed by Yarn. Yarn executes commands
+ * using `bash -c "command arg1 arg2"` and that means plain quoting doesn't really work. The
+ * argument is enclosed in single quotes and some key characters are escaped.
+ *
+ * @param arg A single argument.
+ * @return Argument quoted for execution via Yarn's generated shell script.
+ */
+ def escapeForShell(arg: String): String = {
+ if (arg != null) {
+ val escaped = new StringBuilder("'")
+ for (i <- 0 to arg.length() - 1) {
+ arg.charAt(i) match {
+ case '$' => escaped.append("\\$")
+ case '"' => escaped.append("\\\"")
+ case '\'' => escaped.append("'\\''")
+ case c => escaped.append(c)
+ }
+ }
+ escaped.append("'").toString()
+ } else {
+ arg
+ }
+ }
+
}
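For reference, the escaping rules above produce output like the following (a hand-checked sketch of escapeForShell's behavior, not output from the test suite):

    Seq("plain", "$VAR", "say \"hi\"", "it's")
      .map(YarnSparkHadoopUtil.escapeForShell)
    // -> List('plain', '\$VAR', 'say \"hi\"', 'it'\''s')

Each argument is wrapped in single quotes, with $ and " backslash-escaped for the outer `bash -c "..."` double-quote context, and embedded single quotes spliced in via the usual '\'' idiom.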
diff --git a/yarn/common/src/test/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtilSuite.scala b/yarn/common/src/test/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtilSuite.scala
new file mode 100644
index 0000000000..7650bd4396
--- /dev/null
+++ b/yarn/common/src/test/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtilSuite.scala
@@ -0,0 +1,64 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.deploy.yarn
+
+import java.io.{File, IOException}
+
+import com.google.common.io.{ByteStreams, Files}
+import org.scalatest.{FunSuite, Matchers}
+
+import org.apache.spark.Logging
+
+class YarnSparkHadoopUtilSuite extends FunSuite with Matchers with Logging {
+
+ val hasBash =
+ try {
+ val exitCode = Runtime.getRuntime().exec(Array("bash", "--version")).waitFor()
+ exitCode == 0
+ } catch {
+ case e: IOException =>
+ false
+ }
+
+ if (!hasBash) {
+ logWarning("Cannot execute bash, skipping bash tests.")
+ }
+
+ def bashTest(name: String)(fn: => Unit) =
+ if (hasBash) test(name)(fn) else ignore(name)(fn)
+
+ bashTest("shell script escaping") {
+ val scriptFile = File.createTempFile("script.", ".sh")
+ val args = Array("arg1", "${arg.2}", "\"arg3\"", "'arg4'", "$arg5", "\\arg6")
+ try {
+ val argLine = args.map(a => YarnSparkHadoopUtil.escapeForShell(a)).mkString(" ")
+ Files.write(("bash -c \"echo " + argLine + "\"").getBytes(), scriptFile)
+ scriptFile.setExecutable(true)
+
+ val proc = Runtime.getRuntime().exec(Array(scriptFile.getAbsolutePath()))
+ val out = new String(ByteStreams.toByteArray(proc.getInputStream())).trim()
+ val err = new String(ByteStreams.toByteArray(proc.getErrorStream()))
+ val exitCode = proc.waitFor()
+ exitCode should be (0)
+ out should be (args.mkString(" "))
+ } finally {
+ scriptFile.delete()
+ }
+ }
+
+}