Diffstat (limited to 'core')
-rw-r--r--  core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala          | 10 +++++-----
-rw-r--r--  core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala | 76 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++----
-rw-r--r--  core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala     |  2 +-
3 files changed, 78 insertions(+), 10 deletions(-)
diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala
index 8cf4d58847..3aa3f948e8 100644
--- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala
@@ -82,13 +82,13 @@ object SparkSubmit {
private val CLASS_NOT_FOUND_EXIT_STATUS = 101
// Exposed for testing
- private[spark] var exitFn: () => Unit = () => System.exit(1)
+ private[spark] var exitFn: Int => Unit = (exitCode: Int) => System.exit(exitCode)
private[spark] var printStream: PrintStream = System.err
private[spark] def printWarning(str: String): Unit = printStream.println("Warning: " + str)
private[spark] def printErrorAndExit(str: String): Unit = {
printStream.println("Error: " + str)
printStream.println("Run with --help for usage help or --verbose for debug output")
- exitFn()
+ exitFn(1)
}
private[spark] def printVersionAndExit(): Unit = {
printStream.println("""Welcome to
@@ -99,7 +99,7 @@ object SparkSubmit {
/_/
""".format(SPARK_VERSION))
printStream.println("Type --help for more information.")
- exitFn()
+ exitFn(0)
}
def main(args: Array[String]): Unit = {
@@ -160,7 +160,7 @@ object SparkSubmit {
// detect exceptions with empty stack traces here, and treat them differently.
if (e.getStackTrace().length == 0) {
printStream.println(s"ERROR: ${e.getClass().getName()}: ${e.getMessage()}")
- exitFn()
+ exitFn(1)
} else {
throw e
}
@@ -700,7 +700,7 @@ object SparkSubmit {
/**
* Return whether the given main class represents a sql shell.
*/
- private def isSqlShell(mainClass: String): Boolean = {
+ private[deploy] def isSqlShell(mainClass: String): Boolean = {
mainClass == "org.apache.spark.sql.hive.thriftserver.SparkSQLCLIDriver"
}
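
A note on the exitFn change above: widening the hook from `() => Unit` to `Int => Unit` lets each exit path report a meaningful status (0 for --version, 1 for errors) while tests can still intercept it instead of killing the JVM. A minimal, self-contained sketch of the pattern follows; the object and method names are illustrative, not Spark's actual code:

    object ExitHookSketch {
      // Default behavior: terminate the JVM with the given code.
      var exitFn: Int => Unit = (code: Int) => sys.exit(code)

      def printErrorAndExit(msg: String): Unit = {
        Console.err.println(s"Error: $msg")
        exitFn(1) // error paths exit non-zero
      }

      def main(args: Array[String]): Unit = {
        // A test swaps in a recorder before exercising the code under test:
        var observed = -1
        exitFn = (code: Int) => observed = code
        printErrorAndExit("boom")
        assert(observed == 1)
      }
    }
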
diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala
index cc6a7bd9f4..b7429a901e 100644
--- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala
@@ -17,12 +17,15 @@
package org.apache.spark.deploy
+import java.io.{ByteArrayOutputStream, PrintStream}
+import java.lang.reflect.InvocationTargetException
import java.net.URI
import java.util.{List => JList}
import java.util.jar.JarFile
import scala.collection.JavaConversions._
import scala.collection.mutable.{ArrayBuffer, HashMap}
+import scala.io.Source
import org.apache.spark.deploy.SparkSubmitAction._
import org.apache.spark.launcher.SparkSubmitArgumentsParser
@@ -412,6 +415,9 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S
case VERSION =>
SparkSubmit.printVersionAndExit()
+ case USAGE_ERROR =>
+ printUsageAndExit(1)
+
case _ =>
throw new IllegalArgumentException(s"Unexpected argument '$opt'.")
}
@@ -449,11 +455,14 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S
if (unknownParam != null) {
outStream.println("Unknown/unsupported param " + unknownParam)
}
- outStream.println(
+ val command = sys.env.get("_SPARK_CMD_USAGE").getOrElse(
"""Usage: spark-submit [options] <app jar | python file> [app arguments]
|Usage: spark-submit --kill [submission ID] --master [spark://...]
- |Usage: spark-submit --status [submission ID] --master [spark://...]
- |
+ |Usage: spark-submit --status [submission ID] --master [spark://...]""".stripMargin)
+ outStream.println(command)
+
+ outStream.println(
+ """
|Options:
| --master MASTER_URL spark://host:port, mesos://host:port, yarn, or local.
| --deploy-mode DEPLOY_MODE Whether to launch the driver program locally ("client") or
@@ -525,6 +534,65 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S
| delegation tokens periodically.
""".stripMargin
)
- SparkSubmit.exitFn()
+
+ if (SparkSubmit.isSqlShell(mainClass)) {
+ outStream.println("CLI options:")
+ outStream.println(getSqlShellOptions())
+ }
+
+ SparkSubmit.exitFn(exitCode)
}
+
+ /**
+ * Run the Spark SQL CLI main class with the "--help" option and catch its output. Then filter
+ * the results to remove unwanted lines.
+ *
+ * Since the CLI will call `System.exit()`, we install a security manager to prevent that call
+ * from working, and restore the original one afterwards.
+ */
+ private def getSqlShellOptions(): String = {
+ val currentOut = System.out
+ val currentErr = System.err
+ val currentSm = System.getSecurityManager()
+ try {
+ val out = new ByteArrayOutputStream()
+ val stream = new PrintStream(out)
+ System.setOut(stream)
+ System.setErr(stream)
+
+ val sm = new SecurityManager() {
+ override def checkExit(status: Int): Unit = {
+ throw new SecurityException()
+ }
+
+ override def checkPermission(perm: java.security.Permission): Unit = {}
+ }
+ System.setSecurityManager(sm)
+
+ try {
+ Class.forName(mainClass).getMethod("main", classOf[Array[String]])
+ .invoke(null, Array(HELP))
+ } catch {
+ case e: InvocationTargetException =>
+ // Ignore SecurityException, since we throw it above.
+ if (!e.getCause().isInstanceOf[SecurityException]) {
+ throw e
+ }
+ }
+
+ stream.flush()
+
+ // Get the output and discard any unnecessary lines from it.
+ Source.fromString(new String(out.toByteArray())).getLines
+ .filter { line =>
+ !line.startsWith("log4j") && !line.startsWith("usage")
+ }
+ .mkString("\n")
+ } finally {
+ System.setSecurityManager(currentSm)
+ System.setOut(currentOut)
+ System.setErr(currentErr)
+ }
+ }
+
}
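
The getSqlShellOptions() helper added above relies on a classic JVM trick: to invoke a foreign main() that may call System.exit(), install a SecurityManager whose checkExit throws, run the method, and restore the original manager afterwards. The sketch below isolates that technique outside Spark; it is a simplified rendering, and note that SecurityManager (current on the JDKs Spark targeted at the time) is deprecated on modern JDKs:

    import java.io.{ByteArrayOutputStream, PrintStream}

    object ExitTrapSketch {
      // Capture everything `body` writes to stdout/stderr, swallowing
      // any System.exit() it attempts along the way.
      def captureOutput(body: => Unit): String = {
        val buf = new ByteArrayOutputStream()
        val sink = new PrintStream(buf)
        val (oldOut, oldErr, oldSm) = (System.out, System.err, System.getSecurityManager)
        try {
          System.setOut(sink)
          System.setErr(sink)
          System.setSecurityManager(new SecurityManager {
            override def checkExit(status: Int): Unit = throw new SecurityException()
            override def checkPermission(perm: java.security.Permission): Unit = ()
          })
          try body catch { case _: SecurityException => () } // the trapped exit
          sink.flush()
          new String(buf.toByteArray())
        } finally {
          System.setSecurityManager(oldSm)
          System.setOut(oldOut)
          System.setErr(oldErr)
        }
      }

      def main(args: Array[String]): Unit = {
        val out = captureOutput { System.out.println("usage: demo"); sys.exit(2) }
        assert(out.contains("usage: demo"))
      }
    }

One difference from the patch: SparkSubmitArguments invokes the CLI's main via reflection, so the SecurityException surfaces wrapped in an InvocationTargetException and is unwrapped via getCause(); here body runs directly, so the exception is caught as-is.
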
diff --git a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala
index 46369457f0..46ea28d0f1 100644
--- a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala
+++ b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala
@@ -62,7 +62,7 @@ class SparkSubmitSuite
SparkSubmit.printStream = printStream
@volatile var exitedCleanly = false
- SparkSubmit.exitFn = () => exitedCleanly = true
+ SparkSubmit.exitFn = (_) => exitedCleanly = true
val thread = new Thread {
override def run() = try {
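
The suite change above simply stubs exitFn with a function that accepts (and ignores) the exit code. Stepping back to the usage-banner change in SparkSubmitArguments.scala: printUsageAndExit now reads its header from the _SPARK_CMD_USAGE environment variable when set, falling back to the generic spark-submit synopsis (presumably so that wrapper scripts can advertise their own command line). A toy rendering of that lookup-with-fallback; apart from the variable name, everything here is illustrative:

    object UsageBannerSketch {
      // Prefer an env-provided usage string; fall back to a built-in default.
      // _SPARK_CMD_USAGE is the real variable name from the patch; the
      // default text here is abbreviated.
      def usageHeader(env: Map[String, String]): String =
        env.getOrElse(
          "_SPARK_CMD_USAGE",
          "Usage: spark-submit [options] <app jar | python file> [app arguments]")

      def main(args: Array[String]): Unit = {
        println(usageHeader(sys.env)) // generic banner unless the env var is set
        println(usageHeader(Map("_SPARK_CMD_USAGE" -> "Usage: pyspark [options]")))
      }
    }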