diff options
Diffstat (limited to 'src/main/scala/xyz/driver/core/rest/package.scala')
-rw-r--r-- | src/main/scala/xyz/driver/core/rest/package.scala | 20 |
1 files changed, 12 insertions, 8 deletions
diff --git a/src/main/scala/xyz/driver/core/rest/package.scala b/src/main/scala/xyz/driver/core/rest/package.scala index d4d01df..7d67138 100644 --- a/src/main/scala/xyz/driver/core/rest/package.scala +++ b/src/main/scala/xyz/driver/core/rest/package.scala @@ -14,6 +14,7 @@ import akka.util.ByteString import scalaz.Scalaz.{intInstance, stringInstance} import scalaz.syntax.equal._ import scalaz.{Functor, OptionT} +import xyz.driver.core.rest.auth.AuthProvider import xyz.driver.tracing.TracingDirectives import scala.concurrent.Future @@ -90,13 +91,6 @@ object `package` { val SpanHeaderName: String = TracingDirectives.SpanHeaderName } - object AuthProvider { - val AuthenticationTokenHeader: String = ContextHeaders.AuthenticationTokenHeader - val PermissionsTokenHeader: String = ContextHeaders.PermissionsTokenHeader - val SetAuthenticationTokenHeader: String = "set-authorization" - val SetPermissionsTokenHeader: String = "set-permissions" - } - val AllowedHeaders: Seq[String] = Seq( "Origin", @@ -131,8 +125,18 @@ object `package` { originHeader.fold[HttpOriginRange](HttpOriginRange.*)(h => HttpOriginRange(h.origins: _*))) def serviceContext: Directive1[ServiceRequestContext] = { + def fixAuthorizationHeader(headers: Seq[HttpHeader]): collection.immutable.Seq[HttpHeader] = { + headers.map({ header => + if (header.name === ContextHeaders.AuthenticationTokenHeader && !header.value.startsWith( + ContextHeaders.AuthenticationHeaderPrefix)) { + Authorization(OAuth2BearerToken(header.value)) + } else header + })(collection.breakOut) + } extractClientIP flatMap { remoteAddress => - extract(ctx => extractServiceContext(ctx.request, remoteAddress)) + mapRequest(req => req.withHeaders(fixAuthorizationHeader(req.headers))) tflatMap { _ => + extract(ctx => extractServiceContext(ctx.request, remoteAddress)) + } } } |