diff options
Diffstat (limited to 'src/main/scala/xyz/driver/core/rest.scala')
-rw-r--r-- | src/main/scala/xyz/driver/core/rest.scala | 131 |
1 files changed, 100 insertions, 31 deletions
diff --git a/src/main/scala/xyz/driver/core/rest.scala b/src/main/scala/xyz/driver/core/rest.scala index 554fe75..437df3c 100644 --- a/src/main/scala/xyz/driver/core/rest.scala +++ b/src/main/scala/xyz/driver/core/rest.scala @@ -3,14 +3,14 @@ package xyz.driver.core import akka.actor.ActorSystem import akka.http.scaladsl.Http import akka.http.scaladsl.model._ -import akka.http.scaladsl.model.headers.RawHeader +import akka.http.scaladsl.model.headers.{HttpChallenges, RawHeader} +import akka.http.scaladsl.server.AuthenticationFailedRejection.CredentialsRejected import akka.http.scaladsl.unmarshalling.Unmarshal import akka.stream.ActorMaterializer import com.github.swagger.akka.model._ import com.github.swagger.akka.{HasActorSystem, SwaggerHttpService} import com.typesafe.config.Config import io.swagger.models.Scheme -import xyz.driver.core.auth._ import xyz.driver.core.logging.Logger import xyz.driver.core.stats.Stats import xyz.driver.core.time.TimeRange @@ -18,50 +18,119 @@ import xyz.driver.core.time.provider.TimeProvider import scala.concurrent.{ExecutionContext, Future} import scala.util.{Failure, Success} +import scalaz.OptionT import scalaz.Scalaz.{Id => _, _} object rest { - object ContextHeaders { - val AuthenticationTokenHeader = "WWW-Authenticate" - val TrackingIdHeader = "X-Trace" - - object LinkerD { - // https://linkerd.io/doc/0.7.4/linkerd/protocol-http/ - def isLinkerD(headerName: String) = headerName.startsWith("l5d-") - } - } - final case class ServiceRequestContext( trackingId: String = generators.nextUuid().toString, contextHeaders: Map[String, String] = Map.empty[String, String]) { - def authToken: Option[AuthToken] = contextHeaders.get(AuthService.AuthenticationTokenHeader).map(AuthToken.apply) + def authToken: Option[Auth.AuthToken] = + contextHeaders.get(Auth.AuthProvider.AuthenticationTokenHeader).map(Auth.AuthToken.apply) } - import akka.http.scaladsl.server._ - import Directives._ + object ServiceRequestContext { - def serviceContext: Directive1[ServiceRequestContext] = extract(ctx => extractServiceContext(ctx)) + object ContextHeaders { + val AuthenticationTokenHeader = "WWW-Authenticate" + val TrackingIdHeader = "X-Trace" - def extractServiceContext(ctx: RequestContext): ServiceRequestContext = - ServiceRequestContext(extractTrackingId(ctx), extractContextHeaders(ctx)) + object LinkerD { + // https://linkerd.io/doc/0.7.4/linkerd/protocol-http/ + def isLinkerD(headerName: String) = headerName.startsWith("l5d-") + } + } - def extractTrackingId(ctx: RequestContext): String = { - ctx.request.headers - .find(_.name == ContextHeaders.TrackingIdHeader) - .fold(java.util.UUID.randomUUID.toString)(_.value()) - } + import akka.http.scaladsl.server._ + import Directives._ - def extractContextHeaders(ctx: RequestContext): Map[String, String] = { - ctx.request.headers.filter { h => - h.name === ContextHeaders.AuthenticationTokenHeader || h.name === ContextHeaders.TrackingIdHeader - // || ContextHeaders.LinkerD.isLinkerD(h.lowercaseName) - } map { header => - header.name -> header.value - } toMap + def serviceContext: Directive1[ServiceRequestContext] = extract(ctx => extractServiceContext(ctx)) + + def extractServiceContext(ctx: RequestContext): ServiceRequestContext = + ServiceRequestContext(extractTrackingId(ctx), extractContextHeaders(ctx)) + + def extractTrackingId(ctx: RequestContext): String = { + ctx.request.headers + .find(_.name == ContextHeaders.TrackingIdHeader) + .fold(java.util.UUID.randomUUID.toString)(_.value()) + } + + def extractContextHeaders(ctx: RequestContext): Map[String, String] = { + ctx.request.headers.filter { h => + h.name === ContextHeaders.AuthenticationTokenHeader || h.name === ContextHeaders.TrackingIdHeader + // || ContextHeaders.LinkerD.isLinkerD(h.lowercaseName) + } map { header => + header.name -> header.value + } toMap + } } + object Auth { + + trait Permission + + trait Role { + val id: Id[Role] + val name: Name[Role] + val permissions: Set[Permission] + + def hasPermission(permission: Permission): Boolean = permissions.contains(permission) + } + + trait User { + def id: Id[User] + def roles: Set[Role] + def permissions: Set[Permission] = roles.flatMap(_.permissions) + } + + final case class BasicUser(id: Id[User], roles: Set[Role]) extends User + + final case class AuthToken(value: String) + + final case class PasswordHash(value: String) + + object AuthProvider { + val AuthenticationTokenHeader = ServiceRequestContext.ContextHeaders.AuthenticationTokenHeader + val SetAuthenticationTokenHeader = "set-authorization" + } + + trait AuthProvider[U <: User] { + + import akka.http.scaladsl.server._ + import Directives._ + + /** + * Specific implementation on how to extract user from request context, + * can either need to do a network call to auth server or extract everything from self-contained token + * + * @param context set of request values which can be relevant to authenticate user + * @return authenticated user + */ + protected def authenticatedUser(context: ServiceRequestContext): OptionT[Future, U] + + def authorize(permissions: Permission*): Directive1[U] = { + ServiceRequestContext.serviceContext flatMap { ctx => + onComplete(authenticatedUser(ctx).run).flatMap { + case Success(Some(user)) => + if (permissions.forall(user.permissions.contains)) provide(user) + else { + val challenge = + HttpChallenges.basic(s"User does not have the required permissions: ${permissions.mkString(", ")}") + reject(AuthenticationFailedRejection(CredentialsRejected, challenge)) + } + + case Success(None) => + reject(ValidationRejection(s"Wasn't able to find authenticated user for the token provided")) + + case Failure(t) => + reject(ValidationRejection(s"Wasn't able to verify token for authenticated user", Some(t))) + } + } + } + } + } trait Service @@ -86,7 +155,7 @@ object rest { val requestTime = time.currentTime() val request = requestStub - .withHeaders(RawHeader(ContextHeaders.TrackingIdHeader, context.trackingId)) + .withHeaders(RawHeader(ServiceRequestContext.ContextHeaders.TrackingIdHeader, context.trackingId)) .withHeaders(context.contextHeaders.toSeq.map { h => RawHeader(h._1, h._2): HttpHeader }: _*) log.audit(s"Sending to ${request.uri} request $request with tracking id ${context.trackingId}") |