aboutsummaryrefslogtreecommitdiff
path: root/core
diff options
context:
space:
mode:
Diffstat (limited to 'core')
-rw-r--r--core/src/main/scala/org/apache/spark/TestUtils.scala6
-rw-r--r--core/src/main/scala/org/apache/spark/status/api/v1/ApiRootResource.scala78
-rw-r--r--core/src/main/scala/org/apache/spark/status/api/v1/SecurityFilter.scala6
-rw-r--r--core/src/main/scala/org/apache/spark/ui/JettyUtils.scala4
-rw-r--r--core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala62
5 files changed, 115 insertions, 41 deletions
diff --git a/core/src/main/scala/org/apache/spark/TestUtils.scala b/core/src/main/scala/org/apache/spark/TestUtils.scala
index 109104f0a5..3f912dc191 100644
--- a/core/src/main/scala/org/apache/spark/TestUtils.scala
+++ b/core/src/main/scala/org/apache/spark/TestUtils.scala
@@ -200,9 +200,13 @@ private[spark] object TestUtils {
/**
* Returns the response code from an HTTP(S) URL.
*/
- def httpResponseCode(url: URL, method: String = "GET"): Int = {
+ def httpResponseCode(
+ url: URL,
+ method: String = "GET",
+ headers: Seq[(String, String)] = Nil): Int = {
val connection = url.openConnection().asInstanceOf[HttpURLConnection]
connection.setRequestMethod(method)
+ headers.foreach { case (k, v) => connection.setRequestProperty(k, v) }
// Disable cert and host name validation for HTTPS tests.
if (connection.isInstanceOf[HttpsURLConnection]) {
diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/ApiRootResource.scala b/core/src/main/scala/org/apache/spark/status/api/v1/ApiRootResource.scala
index 17bc04303f..67ccf43afa 100644
--- a/core/src/main/scala/org/apache/spark/status/api/v1/ApiRootResource.scala
+++ b/core/src/main/scala/org/apache/spark/status/api/v1/ApiRootResource.scala
@@ -18,6 +18,7 @@ package org.apache.spark.status.api.v1
import java.util.zip.ZipOutputStream
import javax.servlet.ServletContext
+import javax.servlet.http.HttpServletRequest
import javax.ws.rs._
import javax.ws.rs.core.{Context, Response}
@@ -40,7 +41,7 @@ import org.apache.spark.ui.SparkUI
* HistoryServerSuite.
*/
@Path("/v1")
-private[v1] class ApiRootResource extends UIRootFromServletContext {
+private[v1] class ApiRootResource extends ApiRequestContext {
@Path("applications")
def getApplicationList(): ApplicationListResource = {
@@ -56,21 +57,21 @@ private[v1] class ApiRootResource extends UIRootFromServletContext {
def getJobs(
@PathParam("appId") appId: String,
@PathParam("attemptId") attemptId: String): AllJobsResource = {
- uiRoot.withSparkUI(appId, Some(attemptId)) { ui =>
+ withSparkUI(appId, Some(attemptId)) { ui =>
new AllJobsResource(ui)
}
}
@Path("applications/{appId}/jobs")
def getJobs(@PathParam("appId") appId: String): AllJobsResource = {
- uiRoot.withSparkUI(appId, None) { ui =>
+ withSparkUI(appId, None) { ui =>
new AllJobsResource(ui)
}
}
@Path("applications/{appId}/jobs/{jobId: \\d+}")
def getJob(@PathParam("appId") appId: String): OneJobResource = {
- uiRoot.withSparkUI(appId, None) { ui =>
+ withSparkUI(appId, None) { ui =>
new OneJobResource(ui)
}
}
@@ -79,21 +80,21 @@ private[v1] class ApiRootResource extends UIRootFromServletContext {
def getJob(
@PathParam("appId") appId: String,
@PathParam("attemptId") attemptId: String): OneJobResource = {
- uiRoot.withSparkUI(appId, Some(attemptId)) { ui =>
+ withSparkUI(appId, Some(attemptId)) { ui =>
new OneJobResource(ui)
}
}
@Path("applications/{appId}/executors")
def getExecutors(@PathParam("appId") appId: String): ExecutorListResource = {
- uiRoot.withSparkUI(appId, None) { ui =>
+ withSparkUI(appId, None) { ui =>
new ExecutorListResource(ui)
}
}
@Path("applications/{appId}/allexecutors")
def getAllExecutors(@PathParam("appId") appId: String): AllExecutorListResource = {
- uiRoot.withSparkUI(appId, None) { ui =>
+ withSparkUI(appId, None) { ui =>
new AllExecutorListResource(ui)
}
}
@@ -102,7 +103,7 @@ private[v1] class ApiRootResource extends UIRootFromServletContext {
def getExecutors(
@PathParam("appId") appId: String,
@PathParam("attemptId") attemptId: String): ExecutorListResource = {
- uiRoot.withSparkUI(appId, Some(attemptId)) { ui =>
+ withSparkUI(appId, Some(attemptId)) { ui =>
new ExecutorListResource(ui)
}
}
@@ -111,15 +112,14 @@ private[v1] class ApiRootResource extends UIRootFromServletContext {
def getAllExecutors(
@PathParam("appId") appId: String,
@PathParam("attemptId") attemptId: String): AllExecutorListResource = {
- uiRoot.withSparkUI(appId, Some(attemptId)) { ui =>
+ withSparkUI(appId, Some(attemptId)) { ui =>
new AllExecutorListResource(ui)
}
}
-
@Path("applications/{appId}/stages")
def getStages(@PathParam("appId") appId: String): AllStagesResource = {
- uiRoot.withSparkUI(appId, None) { ui =>
+ withSparkUI(appId, None) { ui =>
new AllStagesResource(ui)
}
}
@@ -128,14 +128,14 @@ private[v1] class ApiRootResource extends UIRootFromServletContext {
def getStages(
@PathParam("appId") appId: String,
@PathParam("attemptId") attemptId: String): AllStagesResource = {
- uiRoot.withSparkUI(appId, Some(attemptId)) { ui =>
+ withSparkUI(appId, Some(attemptId)) { ui =>
new AllStagesResource(ui)
}
}
@Path("applications/{appId}/stages/{stageId: \\d+}")
def getStage(@PathParam("appId") appId: String): OneStageResource = {
- uiRoot.withSparkUI(appId, None) { ui =>
+ withSparkUI(appId, None) { ui =>
new OneStageResource(ui)
}
}
@@ -144,14 +144,14 @@ private[v1] class ApiRootResource extends UIRootFromServletContext {
def getStage(
@PathParam("appId") appId: String,
@PathParam("attemptId") attemptId: String): OneStageResource = {
- uiRoot.withSparkUI(appId, Some(attemptId)) { ui =>
+ withSparkUI(appId, Some(attemptId)) { ui =>
new OneStageResource(ui)
}
}
@Path("applications/{appId}/storage/rdd")
def getRdds(@PathParam("appId") appId: String): AllRDDResource = {
- uiRoot.withSparkUI(appId, None) { ui =>
+ withSparkUI(appId, None) { ui =>
new AllRDDResource(ui)
}
}
@@ -160,14 +160,14 @@ private[v1] class ApiRootResource extends UIRootFromServletContext {
def getRdds(
@PathParam("appId") appId: String,
@PathParam("attemptId") attemptId: String): AllRDDResource = {
- uiRoot.withSparkUI(appId, Some(attemptId)) { ui =>
+ withSparkUI(appId, Some(attemptId)) { ui =>
new AllRDDResource(ui)
}
}
@Path("applications/{appId}/storage/rdd/{rddId: \\d+}")
def getRdd(@PathParam("appId") appId: String): OneRDDResource = {
- uiRoot.withSparkUI(appId, None) { ui =>
+ withSparkUI(appId, None) { ui =>
new OneRDDResource(ui)
}
}
@@ -176,7 +176,7 @@ private[v1] class ApiRootResource extends UIRootFromServletContext {
def getRdd(
@PathParam("appId") appId: String,
@PathParam("attemptId") attemptId: String): OneRDDResource = {
- uiRoot.withSparkUI(appId, Some(attemptId)) { ui =>
+ withSparkUI(appId, Some(attemptId)) { ui =>
new OneRDDResource(ui)
}
}
@@ -234,19 +234,6 @@ private[spark] trait UIRoot {
.status(Response.Status.SERVICE_UNAVAILABLE)
.build()
}
-
- /**
- * Get the spark UI with the given appID, and apply a function
- * to it. If there is no such app, throw an appropriate exception
- */
- def withSparkUI[T](appId: String, attemptId: Option[String])(f: SparkUI => T): T = {
- val appKey = attemptId.map(appId + "/" + _).getOrElse(appId)
- getSparkUI(appKey) match {
- case Some(ui) =>
- f(ui)
- case None => throw new NotFoundException("no such app: " + appId)
- }
- }
def securityManager: SecurityManager
}
@@ -263,13 +250,38 @@ private[v1] object UIRootFromServletContext {
}
}
-private[v1] trait UIRootFromServletContext {
+private[v1] trait ApiRequestContext {
+ @Context
+ protected var servletContext: ServletContext = _
+
@Context
- var servletContext: ServletContext = _
+ protected var httpRequest: HttpServletRequest = _
def uiRoot: UIRoot = UIRootFromServletContext.getUiRoot(servletContext)
+
+
+ /**
+ * Get the spark UI with the given appID, and apply a function
+ * to it. If there is no such app, throw an appropriate exception
+ */
+ def withSparkUI[T](appId: String, attemptId: Option[String])(f: SparkUI => T): T = {
+ val appKey = attemptId.map(appId + "/" + _).getOrElse(appId)
+ uiRoot.getSparkUI(appKey) match {
+ case Some(ui) =>
+ val user = httpRequest.getRemoteUser()
+ if (!ui.securityManager.checkUIViewPermissions(user)) {
+ throw new ForbiddenException(raw"""user "$user" is not authorized""")
+ }
+ f(ui)
+ case None => throw new NotFoundException("no such app: " + appId)
+ }
+ }
+
}
+private[v1] class ForbiddenException(msg: String) extends WebApplicationException(
+ Response.status(Response.Status.FORBIDDEN).entity(msg).build())
+
private[v1] class NotFoundException(msg: String) extends WebApplicationException(
new NoSuchElementException(msg),
Response
diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/SecurityFilter.scala b/core/src/main/scala/org/apache/spark/status/api/v1/SecurityFilter.scala
index b4a991eda3..1cd37185d6 100644
--- a/core/src/main/scala/org/apache/spark/status/api/v1/SecurityFilter.scala
+++ b/core/src/main/scala/org/apache/spark/status/api/v1/SecurityFilter.scala
@@ -21,14 +21,14 @@ import javax.ws.rs.core.Response
import javax.ws.rs.ext.Provider
@Provider
-private[v1] class SecurityFilter extends ContainerRequestFilter with UIRootFromServletContext {
+private[v1] class SecurityFilter extends ContainerRequestFilter with ApiRequestContext {
override def filter(req: ContainerRequestContext): Unit = {
- val user = Option(req.getSecurityContext.getUserPrincipal).map { _.getName }.orNull
+ val user = httpRequest.getRemoteUser()
if (!uiRoot.securityManager.checkUIViewPermissions(user)) {
req.abortWith(
Response
.status(Response.Status.FORBIDDEN)
- .entity(raw"""user "$user"is not authorized""")
+ .entity(raw"""user "$user" is not authorized""")
.build()
)
}
diff --git a/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala b/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala
index 7909821db9..bdbdba5780 100644
--- a/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala
+++ b/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala
@@ -90,9 +90,9 @@ private[spark] object JettyUtils extends Logging {
response.setHeader("X-Frame-Options", xFrameOptionsValue)
response.getWriter.print(servletParams.extractFn(result))
} else {
- response.setStatus(HttpServletResponse.SC_UNAUTHORIZED)
+ response.setStatus(HttpServletResponse.SC_FORBIDDEN)
response.setHeader("Cache-Control", "no-cache, no-store, must-revalidate")
- response.sendError(HttpServletResponse.SC_UNAUTHORIZED,
+ response.sendError(HttpServletResponse.SC_FORBIDDEN,
"User is not authorized to access this page.")
}
} catch {
diff --git a/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala b/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala
index b2eded43ba..dcf83cb530 100644
--- a/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala
@@ -20,7 +20,8 @@ import java.io.{File, FileInputStream, FileWriter, InputStream, IOException}
import java.net.{HttpURLConnection, URL}
import java.nio.charset.StandardCharsets
import java.util.zip.ZipInputStream
-import javax.servlet.http.{HttpServletRequest, HttpServletResponse}
+import javax.servlet._
+import javax.servlet.http.{HttpServletRequest, HttpServletRequestWrapper, HttpServletResponse}
import scala.concurrent.duration._
import scala.language.postfixOps
@@ -68,11 +69,12 @@ class HistoryServerSuite extends SparkFunSuite with BeforeAndAfter with Matchers
private var server: HistoryServer = null
private var port: Int = -1
- def init(): Unit = {
+ def init(extraConf: (String, String)*): Unit = {
val conf = new SparkConf()
.set("spark.history.fs.logDirectory", logDir)
.set("spark.history.fs.update.interval", "0")
.set("spark.testing", "true")
+ conf.setAll(extraConf)
provider = new FsHistoryProvider(conf)
provider.checkForLogs()
val securityManager = HistoryServer.createSecurityManager(conf)
@@ -566,6 +568,39 @@ class HistoryServerSuite extends SparkFunSuite with BeforeAndAfter with Matchers
}
+ test("ui and api authorization checks") {
+ val appId = "app-20161115172038-0000"
+ val owner = "jose"
+ val admin = "root"
+ val other = "alice"
+
+ stop()
+ init(
+ "spark.ui.filters" -> classOf[FakeAuthFilter].getName(),
+ "spark.history.ui.acls.enable" -> "true",
+ "spark.history.ui.admin.acls" -> admin)
+
+ val tests = Seq(
+ (owner, HttpServletResponse.SC_OK),
+ (admin, HttpServletResponse.SC_OK),
+ (other, HttpServletResponse.SC_FORBIDDEN),
+ // When the remote user is null, the code behaves as if auth were disabled.
+ (null, HttpServletResponse.SC_OK))
+
+ val port = server.boundPort
+ val testUrls = Seq(
+ s"http://localhost:$port/api/v1/applications/$appId/jobs",
+ s"http://localhost:$port/history/$appId/jobs/")
+
+ tests.foreach { case (user, expectedCode) =>
+ testUrls.foreach { url =>
+ val headers = if (user != null) Seq(FakeAuthFilter.FAKE_HTTP_USER -> user) else Nil
+ val sc = TestUtils.httpResponseCode(new URL(url), headers = headers)
+ assert(sc === expectedCode, s"Unexpected status code $sc for $url (user = $user)")
+ }
+ }
+ }
+
def getContentAndCode(path: String, port: Int = port): (Int, Option[String], Option[String]) = {
HistoryServerSuite.getContentAndCode(new URL(s"http://localhost:$port/api/v1/$path"))
}
@@ -648,3 +683,26 @@ object HistoryServerSuite {
}
}
}
+
+/**
+ * A filter used for auth tests; sets the request's user to the value of the "HTTP_USER" header.
+ */
+class FakeAuthFilter extends Filter {
+
+ override def destroy(): Unit = { }
+
+ override def init(config: FilterConfig): Unit = { }
+
+ override def doFilter(req: ServletRequest, res: ServletResponse, chain: FilterChain): Unit = {
+ val hreq = req.asInstanceOf[HttpServletRequest]
+ val wrapped = new HttpServletRequestWrapper(hreq) {
+ override def getRemoteUser(): String = hreq.getHeader(FakeAuthFilter.FAKE_HTTP_USER)
+ }
+ chain.doFilter(wrapped, res)
+ }
+
+}
+
+object FakeAuthFilter {
+ val FAKE_HTTP_USER = "HTTP_USER"
+}