aboutsummaryrefslogtreecommitdiff
path: root/src/main/scala/xyz/driver/core/rest.scala
diff options
context:
space:
mode:
Diffstat (limited to 'src/main/scala/xyz/driver/core/rest.scala')
-rw-r--r--src/main/scala/xyz/driver/core/rest.scala131
1 files changed, 100 insertions, 31 deletions
diff --git a/src/main/scala/xyz/driver/core/rest.scala b/src/main/scala/xyz/driver/core/rest.scala
index 554fe75..437df3c 100644
--- a/src/main/scala/xyz/driver/core/rest.scala
+++ b/src/main/scala/xyz/driver/core/rest.scala
@@ -3,14 +3,14 @@ package xyz.driver.core
import akka.actor.ActorSystem
import akka.http.scaladsl.Http
import akka.http.scaladsl.model._
-import akka.http.scaladsl.model.headers.RawHeader
+import akka.http.scaladsl.model.headers.{HttpChallenges, RawHeader}
+import akka.http.scaladsl.server.AuthenticationFailedRejection.CredentialsRejected
import akka.http.scaladsl.unmarshalling.Unmarshal
import akka.stream.ActorMaterializer
import com.github.swagger.akka.model._
import com.github.swagger.akka.{HasActorSystem, SwaggerHttpService}
import com.typesafe.config.Config
import io.swagger.models.Scheme
-import xyz.driver.core.auth._
import xyz.driver.core.logging.Logger
import xyz.driver.core.stats.Stats
import xyz.driver.core.time.TimeRange
@@ -18,50 +18,119 @@ import xyz.driver.core.time.provider.TimeProvider
import scala.concurrent.{ExecutionContext, Future}
import scala.util.{Failure, Success}
+import scalaz.OptionT
import scalaz.Scalaz.{Id => _, _}
object rest {
- object ContextHeaders {
- val AuthenticationTokenHeader = "WWW-Authenticate"
- val TrackingIdHeader = "X-Trace"
-
- object LinkerD {
- // https://linkerd.io/doc/0.7.4/linkerd/protocol-http/
- def isLinkerD(headerName: String) = headerName.startsWith("l5d-")
- }
- }
-
final case class ServiceRequestContext(
trackingId: String = generators.nextUuid().toString,
contextHeaders: Map[String, String] = Map.empty[String, String]) {
- def authToken: Option[AuthToken] = contextHeaders.get(AuthService.AuthenticationTokenHeader).map(AuthToken.apply)
+ def authToken: Option[Auth.AuthToken] =
+ contextHeaders.get(Auth.AuthProvider.AuthenticationTokenHeader).map(Auth.AuthToken.apply)
}
- import akka.http.scaladsl.server._
- import Directives._
+ object ServiceRequestContext {
- def serviceContext: Directive1[ServiceRequestContext] = extract(ctx => extractServiceContext(ctx))
+ object ContextHeaders {
+ val AuthenticationTokenHeader = "WWW-Authenticate"
+ val TrackingIdHeader = "X-Trace"
- def extractServiceContext(ctx: RequestContext): ServiceRequestContext =
- ServiceRequestContext(extractTrackingId(ctx), extractContextHeaders(ctx))
+ object LinkerD {
+ // https://linkerd.io/doc/0.7.4/linkerd/protocol-http/
+ def isLinkerD(headerName: String) = headerName.startsWith("l5d-")
+ }
+ }
- def extractTrackingId(ctx: RequestContext): String = {
- ctx.request.headers
- .find(_.name == ContextHeaders.TrackingIdHeader)
- .fold(java.util.UUID.randomUUID.toString)(_.value())
- }
+ import akka.http.scaladsl.server._
+ import Directives._
- def extractContextHeaders(ctx: RequestContext): Map[String, String] = {
- ctx.request.headers.filter { h =>
- h.name === ContextHeaders.AuthenticationTokenHeader || h.name === ContextHeaders.TrackingIdHeader
- // || ContextHeaders.LinkerD.isLinkerD(h.lowercaseName)
- } map { header =>
- header.name -> header.value
- } toMap
+ def serviceContext: Directive1[ServiceRequestContext] = extract(ctx => extractServiceContext(ctx))
+
+ def extractServiceContext(ctx: RequestContext): ServiceRequestContext =
+ ServiceRequestContext(extractTrackingId(ctx), extractContextHeaders(ctx))
+
+ def extractTrackingId(ctx: RequestContext): String = {
+ ctx.request.headers
+ .find(_.name == ContextHeaders.TrackingIdHeader)
+ .fold(java.util.UUID.randomUUID.toString)(_.value())
+ }
+
+ def extractContextHeaders(ctx: RequestContext): Map[String, String] = {
+ ctx.request.headers.filter { h =>
+ h.name === ContextHeaders.AuthenticationTokenHeader || h.name === ContextHeaders.TrackingIdHeader
+ // || ContextHeaders.LinkerD.isLinkerD(h.lowercaseName)
+ } map { header =>
+ header.name -> header.value
+ } toMap
+ }
}
+ object Auth {
+
+ trait Permission
+
+ trait Role {
+ val id: Id[Role]
+ val name: Name[Role]
+ val permissions: Set[Permission]
+
+ def hasPermission(permission: Permission): Boolean = permissions.contains(permission)
+ }
+
+ trait User {
+ def id: Id[User]
+ def roles: Set[Role]
+ def permissions: Set[Permission] = roles.flatMap(_.permissions)
+ }
+
+ final case class BasicUser(id: Id[User], roles: Set[Role]) extends User
+
+ final case class AuthToken(value: String)
+
+ final case class PasswordHash(value: String)
+
+ object AuthProvider {
+ val AuthenticationTokenHeader = ServiceRequestContext.ContextHeaders.AuthenticationTokenHeader
+ val SetAuthenticationTokenHeader = "set-authorization"
+ }
+
+ trait AuthProvider[U <: User] {
+
+ import akka.http.scaladsl.server._
+ import Directives._
+
+ /**
+ * Specific implementation on how to extract user from request context,
+ * can either need to do a network call to auth server or extract everything from self-contained token
+ *
+ * @param context set of request values which can be relevant to authenticate user
+ * @return authenticated user
+ */
+ protected def authenticatedUser(context: ServiceRequestContext): OptionT[Future, U]
+
+ def authorize(permissions: Permission*): Directive1[U] = {
+ ServiceRequestContext.serviceContext flatMap { ctx =>
+ onComplete(authenticatedUser(ctx).run).flatMap {
+ case Success(Some(user)) =>
+ if (permissions.forall(user.permissions.contains)) provide(user)
+ else {
+ val challenge =
+ HttpChallenges.basic(s"User does not have the required permissions: ${permissions.mkString(", ")}")
+ reject(AuthenticationFailedRejection(CredentialsRejected, challenge))
+ }
+
+ case Success(None) =>
+ reject(ValidationRejection(s"Wasn't able to find authenticated user for the token provided"))
+
+ case Failure(t) =>
+ reject(ValidationRejection(s"Wasn't able to verify token for authenticated user", Some(t)))
+ }
+ }
+ }
+ }
+ }
trait Service
@@ -86,7 +155,7 @@ object rest {
val requestTime = time.currentTime()
val request = requestStub
- .withHeaders(RawHeader(ContextHeaders.TrackingIdHeader, context.trackingId))
+ .withHeaders(RawHeader(ServiceRequestContext.ContextHeaders.TrackingIdHeader, context.trackingId))
.withHeaders(context.contextHeaders.toSeq.map { h => RawHeader(h._1, h._2): HttpHeader }: _*)
log.audit(s"Sending to ${request.uri} request $request with tracking id ${context.trackingId}")