aboutsummaryrefslogtreecommitdiff
path: root/src/main/scala/xyz/driver
diff options
context:
space:
mode:
Diffstat (limited to 'src/main/scala/xyz/driver')
-rw-r--r--src/main/scala/xyz/driver/core/app.scala105
-rw-r--r--src/main/scala/xyz/driver/core/auth.scala1
-rw-r--r--src/main/scala/xyz/driver/core/database/database.scala17
-rw-r--r--src/main/scala/xyz/driver/core/database/package.scala16
-rw-r--r--src/main/scala/xyz/driver/core/json.scala23
-rw-r--r--src/main/scala/xyz/driver/core/rest.scala250
6 files changed, 322 insertions, 90 deletions
diff --git a/src/main/scala/xyz/driver/core/app.scala b/src/main/scala/xyz/driver/core/app.scala
index 1977d6a..eb9f7ee 100644
--- a/src/main/scala/xyz/driver/core/app.scala
+++ b/src/main/scala/xyz/driver/core/app.scala
@@ -12,13 +12,14 @@ import akka.http.scaladsl.server.Directives._
import akka.http.scaladsl.server.RouteResult._
import akka.http.scaladsl.server.{ExceptionHandler, Route, RouteConcatenation}
import akka.stream.ActorMaterializer
+import com.github.swagger.akka.SwaggerHttpService._
import com.typesafe.config.Config
import com.typesafe.scalalogging.Logger
import io.swagger.models.Scheme
+import io.swagger.util.Json
import org.slf4j.{LoggerFactory, MDC}
-import spray.json.DefaultJsonProtocol
import xyz.driver.core
-import xyz.driver.core.rest.{ContextHeaders, Swagger}
+import xyz.driver.core.rest._
import xyz.driver.core.stats.SystemStats
import xyz.driver.core.time.Time
import xyz.driver.core.time.provider.{SystemTimeProvider, TimeProvider}
@@ -26,12 +27,16 @@ import xyz.driver.core.time.provider.{SystemTimeProvider, TimeProvider}
import scala.compat.Platform.ConcurrentModificationException
import scala.concurrent.duration._
import scala.concurrent.{Await, ExecutionContext, Future}
+import scala.reflect.runtime.universe._
+import scala.util.control.NonFatal
+import scala.util.Try
import scalaz.Scalaz.stringInstance
import scalaz.syntax.equal._
object app {
- class DriverApp(version: String,
+ class DriverApp(appName: String,
+ version: String,
gitHash: String,
modules: Seq[Module],
time: TimeProvider = new SystemTimeProvider(),
@@ -66,7 +71,7 @@ object app {
protected def bindHttp(modules: Seq[Module]): Unit = {
val serviceTypes = modules.flatMap(_.routeTypes)
- val swaggerService = new Swagger(baseUrl, Scheme.forValue(scheme), version, actorSystem, serviceTypes, config)
+ val swaggerService = swaggerOverride(serviceTypes)
val swaggerRoutes = swaggerService.routes ~ swaggerService.swaggerUI
val versionRt = versionRoute(version, gitHash, time.currentTime())
@@ -78,14 +83,11 @@ object app {
{ ctx =>
val trackingId = rest.extractTrackingId(ctx.request)
MDC.put("trackingId", trackingId)
- MDC.put("origin", origin)
- MDC.put("xForwardedFor",
- extractHeader(ctx.request)("x-forwarded-for")
- .orElse(extractHeader(ctx.request)("x_forwarded_for"))
- .getOrElse("unknown"))
- MDC.put("remoteAddress", extractHeader(ctx.request)("remote-address").getOrElse("unknown"))
- MDC.put("userAgent", extractHeader(ctx.request)("user-agent").getOrElse("unknown"))
- MDC.put("ip", ip.toOption.map(_.getHostAddress).getOrElse("unknown"))
+
+ val updatedStacktrace = (rest.extractStacktrace(ctx.request) ++ Array(appName)).mkString("->")
+ MDC.put("stack", updatedStacktrace)
+
+ storeRequestContextToMdc(ctx.request, origin, ip)
def requestLogging: Future[Unit] = Future {
log.info(
@@ -93,7 +95,10 @@ object app {
}
val contextWithTrackingId =
- ctx.withRequest(ctx.request.addHeader(RawHeader(ContextHeaders.TrackingIdHeader, trackingId)))
+ ctx.withRequest(
+ ctx.request
+ .addHeader(RawHeader(ContextHeaders.TrackingIdHeader, trackingId))
+ .addHeader(RawHeader(ContextHeaders.StacktraceHeader, updatedStacktrace)))
handleExceptions(ExceptionHandler(exceptionHandler))({ c =>
requestLogging.flatMap { _ =>
@@ -114,6 +119,51 @@ object app {
}
}
+ private def storeRequestContextToMdc(request: HttpRequest, origin: String, ip: RemoteAddress) = {
+
+ MDC.put("origin", origin)
+ MDC.put("ip", ip.toOption.map(_.getHostAddress).getOrElse("unknown"))
+ MDC.put("remoteHost", ip.toOption.map(_.getHostName).getOrElse("unknown"))
+
+ MDC.put("xForwardedFor",
+ extractHeader(request)("x-forwarded-for")
+ .orElse(extractHeader(request)("x_forwarded_for"))
+ .getOrElse("unknown"))
+ MDC.put("remoteAddress", extractHeader(request)("remote-address").getOrElse("unknown"))
+ MDC.put("userAgent", extractHeader(request)("user-agent").getOrElse("unknown"))
+ }
+
+ protected def swaggerOverride(apiTypes: Seq[Type]) = {
+ new Swagger(baseUrl, Scheme.forValue(scheme), version, actorSystem, apiTypes, config) {
+ override def generateSwaggerJson: String = {
+ import io.swagger.models.Swagger
+
+ import scala.collection.JavaConverters._
+
+ try {
+ val swagger: Swagger = reader.read(toJavaTypeSet(apiTypes).asJava)
+
+ // Removing trailing spaces
+ swagger.setPaths(
+ swagger.getPaths.asScala
+ .map {
+ case (key, path) =>
+ key.trim -> path
+ }
+ .toMap
+ .asJava)
+
+ Json.pretty().writeValueAsString(swagger)
+ } catch {
+ case NonFatal(t) => {
+ logger.error("Issue with creating swagger.json", t)
+ throw t
+ }
+ }
+ }
+ }
+ }
+
/**
* Override me for custom exception handling
*
@@ -148,6 +198,7 @@ object app {
}
protected def versionRoute(version: String, gitHash: String, startupTime: Time): Route = {
+ import spray.json._
import DefaultJsonProtocol._
import SprayJsonSupport._
@@ -155,17 +206,27 @@ object app {
val currentTime = time.currentTime().millis
complete(
Map(
- "version" -> version,
- "gitHash" -> gitHash,
- "modules" -> modules.map(_.name).mkString(", "),
- "startupTime" -> startupTime.millis.toString,
- "serverTime" -> currentTime.toString,
- "uptime" -> (currentTime - startupTime.millis).toString
- ))
+ "version" -> version.toJson,
+ "gitHash" -> gitHash.toJson,
+ "modules" -> modules.map(_.name).toJson,
+ "dependencies" -> collectAppDependencies().toJson,
+ "startupTime" -> startupTime.millis.toString.toJson,
+ "serverTime" -> currentTime.toString.toJson,
+ "uptime" -> (currentTime - startupTime.millis).toString.toJson
+ ).toJson)
}
}
+ protected def collectAppDependencies(): Map[String, String] = {
+
+ def serviceWithLocation(serviceName: String): (String, String) =
+ serviceName -> Try(config.getString(s"services.$serviceName.baseUrl")).getOrElse("not-detected")
+
+ modules.flatMap(module => module.serviceDiscovery.getUsedServices.map(serviceWithLocation).toSeq).toMap
+ }
+
protected def healthRoute: Route = {
+ import spray.json._
import DefaultJsonProtocol._
import SprayJsonSupport._
import spray.json._
@@ -236,13 +297,13 @@ object app {
}
}
- import scala.reflect.runtime.universe._
-
trait Module {
val name: String
def route: Route
def routeTypes: Seq[Type]
+ val serviceDiscovery: ServiceDiscovery with SavingUsedServiceDiscovery = new NoServiceDiscovery()
+
def activate(): Unit = {}
def deactivate(): Unit = {}
}
diff --git a/src/main/scala/xyz/driver/core/auth.scala b/src/main/scala/xyz/driver/core/auth.scala
index f9a1a57..5dea2db 100644
--- a/src/main/scala/xyz/driver/core/auth.scala
+++ b/src/main/scala/xyz/driver/core/auth.scala
@@ -23,6 +23,7 @@ object auth {
final case class AuthToken(value: String)
final case class RefreshToken(value: String)
+ final case class PermissionsToken(value: String)
final case class PasswordHash(value: String)
diff --git a/src/main/scala/xyz/driver/core/database/database.scala b/src/main/scala/xyz/driver/core/database/database.scala
index b7a4165..8426309 100644
--- a/src/main/scala/xyz/driver/core/database/database.scala
+++ b/src/main/scala/xyz/driver/core/database/database.scala
@@ -121,6 +121,23 @@ package database {
def naturalKeyMapper[T] = MappedColumnType.base[Id[T], String](_.value, Id[T](_))
}
+ trait CreateAndDropSchema {
+ val slickDal: xyz.driver.core.database.SlickDal
+ val tables: GeneratedTables
+
+ import tables.profile.api._
+ import scala.concurrent.Await
+ import scala.concurrent.duration.Duration
+
+ def createSchema(): Unit = {
+ Await.result(slickDal.execute(tables.createNamespaceSchema >> tables.schema.create), Duration.Inf)
+ }
+
+ def dropSchema(): Unit = {
+ Await.result(slickDal.execute(tables.schema.drop >> tables.dropNamespaceSchema), Duration.Inf)
+ }
+ }
+
trait DatabaseObject extends ColumnTypes {
def createTables(): Future[Unit]
def disconnect(): Unit
diff --git a/src/main/scala/xyz/driver/core/database/package.scala b/src/main/scala/xyz/driver/core/database/package.scala
index 791a688..b39169d 100644
--- a/src/main/scala/xyz/driver/core/database/package.scala
+++ b/src/main/scala/xyz/driver/core/database/package.scala
@@ -4,13 +4,23 @@ import java.sql.{Date => SqlDate}
import java.util.Calendar
import date.{Date, Month}
-import slick.dbio.{DBIOAction, NoStream}
+import slick.dbio._
+import slick.driver.JdbcProfile
package object database {
type Schema = {
- def create: DBIOAction[Unit, NoStream, slick.dbio.Effect.Schema]
- def drop: DBIOAction[Unit, NoStream, slick.dbio.Effect.Schema]
+ def create: DBIOAction[Unit, NoStream, Effect.Schema]
+ def drop: DBIOAction[Unit, NoStream, Effect.Schema]
+ }
+
+ type GeneratedTables = {
+ // structure of Slick data model traits generated by sbt-slick-codegen
+ val profile: JdbcProfile
+ def schema: profile.SchemaDescription
+
+ def createNamespaceSchema: StreamingDBIO[Vector[Unit], Unit]
+ def dropNamespaceSchema: StreamingDBIO[Vector[Unit], Unit]
}
private[database] def sqlDateToDate(sqlDate: SqlDate): Date = {
diff --git a/src/main/scala/xyz/driver/core/json.scala b/src/main/scala/xyz/driver/core/json.scala
index 21bcad5..b203c91 100644
--- a/src/main/scala/xyz/driver/core/json.scala
+++ b/src/main/scala/xyz/driver/core/json.scala
@@ -1,21 +1,27 @@
package xyz.driver.core
+import java.util.UUID
+
+import scala.reflect.runtime.universe._
+import scala.util.Try
+
import akka.http.scaladsl.model.Uri.Path
+import akka.http.scaladsl.server._
import akka.http.scaladsl.server.PathMatcher.{Matched, Unmatched}
-import akka.http.scaladsl.server.{PathMatcher, _}
import akka.http.scaladsl.unmarshalling.Unmarshaller
-import spray.json.{DeserializationException, JsNumber, _}
+import spray.json._
import xyz.driver.core.auth.AuthCredentials
-import xyz.driver.core.time.Time
import xyz.driver.core.date.{Date, Month}
import xyz.driver.core.domain.{Email, PhoneNumber}
-
-import scala.reflect.runtime.universe._
+import xyz.driver.core.time.Time
object json {
import DefaultJsonProtocol._
- def IdInPath[T]: PathMatcher1[Id[T]] = new PathMatcher1[Id[T]] {
+ private def UuidInPath[T]: PathMatcher1[Id[T]] =
+ PathMatchers.JavaUUID.map((id: UUID) => Id[T](id.toString.toLowerCase))
+
+ def IdInPath[T]: PathMatcher1[Id[T]] = UuidInPath[T] | new PathMatcher1[Id[T]] {
def apply(path: Path) = path match {
case Path.Segment(segment, tail) => Matched(tail, Tuple1(Id[T](segment)))
case _ => Unmatched
@@ -26,8 +32,9 @@ object json {
def write(id: Id[T]) = JsString(id.value)
def read(value: JsValue) = value match {
- case JsString(id) => Id[T](id)
- case _ => throw DeserializationException("Id expects string")
+ case JsString(id) if Try(UUID.fromString(id)).isSuccess => Id[T](id.toLowerCase)
+ case JsString(id) => Id[T](id)
+ case _ => throw DeserializationException("Id expects string")
}
}
diff --git a/src/main/scala/xyz/driver/core/rest.scala b/src/main/scala/xyz/driver/core/rest.scala
index f1eab45..f30d1ae 100644
--- a/src/main/scala/xyz/driver/core/rest.scala
+++ b/src/main/scala/xyz/driver/core/rest.scala
@@ -1,5 +1,9 @@
package xyz.driver.core
+import java.nio.file.{Files, Path}
+import java.security.spec.X509EncodedKeySpec
+import java.security.{KeyFactory, PublicKey}
+
import akka.actor.ActorSystem
import akka.http.scaladsl.Http
import akka.http.scaladsl.model._
@@ -7,8 +11,10 @@ import akka.http.scaladsl.model.headers.{HttpChallenges, RawHeader}
import akka.http.scaladsl.server.AuthenticationFailedRejection.CredentialsRejected
import akka.http.scaladsl.server.Directive0
import com.typesafe.scalalogging.Logger
-import akka.http.scaladsl.unmarshalling.Unmarshal
-import akka.http.scaladsl.unmarshalling.Unmarshaller
+import akka.http.scaladsl.unmarshalling.{Unmarshal, Unmarshaller}
+import akka.http.scaladsl.settings.ClientConnectionSettings
+import akka.http.scaladsl.settings.ConnectionPoolSettings
+import akka.http.scaladsl.model.headers.`User-Agent`
import akka.stream.ActorMaterializer
import akka.stream.scaladsl.Flow
import akka.util.ByteString
@@ -16,8 +22,10 @@ import com.github.swagger.akka.model._
import com.github.swagger.akka.{HasActorSystem, SwaggerHttpService}
import com.typesafe.config.Config
import io.swagger.models.Scheme
+import pdi.jwt.{Jwt, JwtAlgorithm}
import xyz.driver.core.auth._
import xyz.driver.core.time.provider.TimeProvider
+import org.slf4j.MDC
import scala.concurrent.{ExecutionContext, Future}
import scala.util.{Failure, Success}
@@ -33,7 +41,7 @@ package rest {
def serviceContext: Directive1[ServiceRequestContext] = extract(ctx => extractServiceContext(ctx.request))
def extractServiceContext(request: HttpRequest): ServiceRequestContext =
- ServiceRequestContext(extractTrackingId(request), extractContextHeaders(request))
+ new ServiceRequestContext(extractTrackingId(request), extractContextHeaders(request))
def extractTrackingId(request: HttpRequest): String = {
request.headers
@@ -41,9 +49,13 @@ package rest {
.fold(java.util.UUID.randomUUID.toString)(_.value())
}
+ def extractStacktrace(request: HttpRequest): Array[String] =
+ request.headers.find(_.name == ContextHeaders.StacktraceHeader).fold("")(_.value()).split("->")
+
def extractContextHeaders(request: HttpRequest): Map[String, String] = {
request.headers.filter { h =>
- h.name === ContextHeaders.AuthenticationTokenHeader || h.name === ContextHeaders.TrackingIdHeader
+ h.name === ContextHeaders.AuthenticationTokenHeader || h.name === ContextHeaders.TrackingIdHeader ||
+ h.name === ContextHeaders.PermissionsTokenHeader || h.name === ContextHeaders.StacktraceHeader
} map { header =>
if (header.name === ContextHeaders.AuthenticationTokenHeader) {
header.name -> header.value.stripPrefix(ContextHeaders.AuthenticationHeaderPrefix).trim
@@ -91,14 +103,59 @@ package rest {
}
}
- final case class ServiceRequestContext(trackingId: String = generators.nextUuid().toString,
- contextHeaders: Map[String, String] = Map.empty[String, String]) {
-
+ class ServiceRequestContext(val trackingId: String = generators.nextUuid().toString,
+ val contextHeaders: Map[String, String] = Map.empty[String, String]) {
def authToken: Option[AuthToken] =
contextHeaders.get(AuthProvider.AuthenticationTokenHeader).map(AuthToken.apply)
+ def permissionsToken: Option[PermissionsToken] =
+ contextHeaders.get(AuthProvider.PermissionsTokenHeader).map(PermissionsToken.apply)
+
def withAuthToken(authToken: AuthToken): ServiceRequestContext =
- copy(contextHeaders = contextHeaders.updated(AuthProvider.AuthenticationTokenHeader, authToken.value))
+ new ServiceRequestContext(
+ trackingId,
+ contextHeaders.updated(AuthProvider.AuthenticationTokenHeader, authToken.value)
+ )
+
+ def withAuthenticatedUser[U <: User](authToken: AuthToken, user: U): AuthorizedServiceRequestContext[U] =
+ new AuthorizedServiceRequestContext(
+ trackingId,
+ contextHeaders.updated(AuthProvider.AuthenticationTokenHeader, authToken.value),
+ user
+ )
+
+ override def hashCode(): Int =
+ Seq[Any](trackingId, contextHeaders).foldLeft(31)((result, obj) => 31 * result + obj.hashCode())
+
+ override def equals(obj: Any): Boolean = obj match {
+ case ctx: ServiceRequestContext => trackingId === ctx.trackingId && contextHeaders === ctx.contextHeaders
+ case _ => false
+ }
+
+ override def toString: String = s"ServiceRequestContext($trackingId, $contextHeaders)"
+ }
+
+ class AuthorizedServiceRequestContext[U <: User](override val trackingId: String = generators.nextUuid().toString,
+ override val contextHeaders: Map[String, String] =
+ Map.empty[String, String],
+ val authenticatedUser: U)
+ extends ServiceRequestContext {
+
+ def withPermissionsToken(permissionsToken: PermissionsToken): AuthorizedServiceRequestContext[U] =
+ new AuthorizedServiceRequestContext[U](
+ trackingId,
+ contextHeaders.updated(AuthProvider.PermissionsTokenHeader, permissionsToken.value),
+ authenticatedUser)
+
+ override def hashCode(): Int = 31 * super.hashCode() + authenticatedUser.hashCode()
+
+ override def equals(obj: Any): Boolean = obj match {
+ case ctx: AuthorizedServiceRequestContext[U] => super.equals(ctx) && ctx.authenticatedUser == authenticatedUser
+ case _ => false
+ }
+
+ override def toString: String =
+ s"AuthorizedServiceRequestContext($trackingId, $contextHeaders, $authenticatedUser)"
}
object ContextHeaders {
@@ -106,6 +163,7 @@ package rest {
val PermissionsTokenHeader = "Permissions"
val AuthenticationHeaderPrefix = "Bearer"
val TrackingIdHeader = "X-Trace"
+ val StacktraceHeader = "X-Stacktrace"
}
object AuthProvider {
@@ -115,18 +173,79 @@ package rest {
val SetPermissionsTokenHeader = "set-permissions"
}
- trait Authorization {
- def userHasPermission(user: User, permission: Permission)(implicit ctx: ServiceRequestContext): Future[Boolean]
+ final case class AuthorizationResult(authorized: Boolean, token: Option[PermissionsToken])
+ object AuthorizationResult {
+ val unauthorized: AuthorizationResult = AuthorizationResult(authorized = false, None)
+ }
+
+ trait Authorization[U <: User] {
+ def userHasPermissions(user: U, permissions: Seq[Permission])(
+ implicit ctx: ServiceRequestContext): Future[AuthorizationResult]
+ }
+
+ class AlwaysAllowAuthorization[U <: User](implicit execution: ExecutionContext) extends Authorization[U] {
+ override def userHasPermissions(user: U, permissions: Seq[Permission])(
+ implicit ctx: ServiceRequestContext): Future[AuthorizationResult] =
+ Future.successful(AuthorizationResult(authorized = true, ctx.permissionsToken))
+ }
+
+ class CachedTokenAuthorization[U <: User](publicKey: => PublicKey, issuer: String) extends Authorization[U] {
+ override def userHasPermissions(user: U, permissions: Seq[Permission])(
+ implicit ctx: ServiceRequestContext): Future[AuthorizationResult] = {
+ import spray.json._
+
+ def extractPermissionsFromTokenJSON(tokenObject: JsObject): Option[Map[String, Boolean]] =
+ tokenObject.fields.get("permissions").collect {
+ case JsObject(fields) =>
+ fields.collect {
+ case (key, JsBoolean(value)) => key -> value
+ }
+ }
+
+ val result = for {
+ token <- ctx.permissionsToken
+ jwt <- Jwt.decode(token.value, publicKey, Seq(JwtAlgorithm.RS256)).toOption
+ jwtJson = jwt.parseJson.asJsObject
+
+ // Ensure jwt is for the currently authenticated user and the correct issuer, otherwise return None
+ _ <- jwtJson.fields.get("sub").contains(JsString(user.id.value)).option(())
+ _ <- jwtJson.fields.get("iss").contains(JsString(issuer)).option(())
+
+ permissionsMap <- extractPermissionsFromTokenJSON(jwtJson)
+
+ authorized = permissions.forall(p => permissionsMap.get(p.toString).contains(true))
+ } yield AuthorizationResult(authorized, Some(token))
+
+ Future.successful(result.getOrElse(AuthorizationResult.unauthorized))
+ }
+ }
+
+ object CachedTokenAuthorization {
+ def apply[U <: User](publicKeyFile: Path, issuer: String): CachedTokenAuthorization[U] = {
+ lazy val publicKey: PublicKey = {
+ val publicKeyBase64Encoded = Files.readAllBytes(publicKeyFile)
+ val publicKeyBase64Decoded = java.util.Base64.getDecoder.decode(publicKeyBase64Encoded)
+ val spec = new X509EncodedKeySpec(publicKeyBase64Decoded)
+ KeyFactory.getInstance("RSA").generatePublic(spec)
+ }
+ new CachedTokenAuthorization[U](publicKey, issuer)
+ }
}
- class AlwaysAllowAuthorization extends Authorization {
- override def userHasPermission(user: User, permission: Permission)(
- implicit ctx: ServiceRequestContext): Future[Boolean] = {
- Future.successful(true)
+ class ChainedAuthorization[U <: User](authorizations: Authorization[U]*)(implicit execution: ExecutionContext)
+ extends Authorization[U] {
+
+ override def userHasPermissions(user: U, permissions: Seq[Permission])(
+ implicit ctx: ServiceRequestContext): Future[AuthorizationResult] = {
+ authorizations.toList.foldLeftM[Future, AuthorizationResult](AuthorizationResult.unauthorized) {
+ (authResult, authorization) =>
+ if (authResult.authorized) Future.successful(authResult)
+ else authorization.userHasPermissions(user, permissions)
+ }
}
}
- abstract class AuthProvider[U <: User](val authorization: Authorization, log: Logger)(
+ abstract class AuthProvider[U <: User](val authorization: Authorization[U], log: Logger)(
implicit execution: ExecutionContext) {
import akka.http.scaladsl.server._
@@ -142,43 +261,30 @@ package rest {
def authenticatedUser(implicit ctx: ServiceRequestContext): OptionT[Future, U]
/**
- * Specific implementation can verify session expiration and single sign out
- * to verify if session is still valid
- */
- def isSessionValid(user: U)(implicit ctx: ServiceRequestContext): Future[Boolean]
-
- /**
* Verifies if request is authenticated and authorized to have `permissions`
*/
- def authorize(permissions: Permission*): Directive1[U] = {
+ def authorize(permissions: Permission*): Directive1[AuthorizedServiceRequestContext[U]] = {
serviceContext flatMap { ctx =>
- onComplete(authenticatedUser(ctx).run flatMap { userOption =>
- userOption.traverseM[Future, (U, Boolean)] { user =>
- isSessionValid(user)(ctx).flatMap { sessionValid =>
- if (sessionValid) {
- permissions.toList
- .traverse[Future, Boolean](authorization.userHasPermission(user, _)(ctx))
- .map(results => Option(user -> results.forall(identity)))
- } else {
- Future.successful(Option.empty[(U, Boolean)])
- }
- }
- }
- }).flatMap {
- case Success(Some((user, authorizationResult))) =>
- if (authorizationResult) provide(user)
- else {
- val challenge =
- HttpChallenges.basic(s"User does not have the required permissions: ${permissions.mkString(", ")}")
- log.warn(s"User $user does not have the required permissions: ${permissions.mkString(", ")}")
- reject(AuthenticationFailedRejection(CredentialsRejected, challenge))
- }
-
+ onComplete {
+ (for {
+ authToken <- OptionT.optionT(Future.successful(ctx.authToken))
+ user <- authenticatedUser(ctx)
+ authCtx = ctx.withAuthenticatedUser(authToken, user)
+ authorizationResult <- authorization.userHasPermissions(user, permissions)(authCtx).toOptionT
+ cachedPermissionsAuthCtx = authorizationResult.token.fold(authCtx)(authCtx.withPermissionsToken)
+ } yield (cachedPermissionsAuthCtx, authorizationResult.authorized)).run
+ } flatMap {
+ case Success(Some((authCtx, true))) => provide(authCtx)
+ case Success(Some((authCtx, false))) =>
+ val challenge =
+ HttpChallenges.basic(s"User does not have the required permissions: ${permissions.mkString(", ")}")
+ log.warn(
+ s"User ${authCtx.authenticatedUser} does not have the required permissions: ${permissions.mkString(", ")}")
+ reject(AuthenticationFailedRejection(CredentialsRejected, challenge))
case Success(None) =>
log.warn(
s"Wasn't able to find authenticated user for the token provided to verify ${permissions.mkString(", ")}")
reject(ValidationRejection(s"Wasn't able to find authenticated user for the token provided"))
-
case Failure(t) =>
log.warn(s"Wasn't able to verify token for authenticated user to verify ${permissions.mkString(", ")}", t)
reject(ValidationRejection(s"Wasn't able to verify token for authenticated user", Some(t)))
@@ -193,7 +299,6 @@ package rest {
import akka.http.scaladsl.marshallers.sprayjson.SprayJsonSupport._
import spray.json._
- import DefaultJsonProtocol._
protected implicit val exec: ExecutionContext
protected implicit val materializer: ActorMaterializer
@@ -217,10 +322,7 @@ package rest {
protected def jsonEntity(json: JsValue): RequestEntity =
HttpEntity(ContentTypes.`application/json`, json.compactPrint)
- protected def get(baseUri: Uri, path: String) =
- HttpRequest(HttpMethods.GET, endpointUri(baseUri, path))
-
- protected def get(baseUri: Uri, path: String, query: Map[String, String]) =
+ protected def get(baseUri: Uri, path: String, query: Seq[(String, String)] = Seq.empty) =
HttpRequest(HttpMethods.GET, endpointUri(baseUri, path, query))
protected def post(baseUri: Uri, path: String, httpEntity: RequestEntity) =
@@ -235,8 +337,8 @@ package rest {
protected def endpointUri(baseUri: Uri, path: String) =
baseUri.withPath(Uri.Path(path))
- protected def endpointUri(baseUri: Uri, path: String, query: Map[String, String]) =
- baseUri.withPath(Uri.Path(path)).withQuery(Uri.Query(query))
+ protected def endpointUri(baseUri: Uri, path: String, query: Seq[(String, String)]) =
+ baseUri.withPath(Uri.Path(path)).withQuery(Uri.Query(query: _*))
}
trait ServiceTransport {
@@ -251,7 +353,26 @@ package rest {
def discover[T <: Service](serviceName: Name[Service]): T
}
- class HttpRestServiceTransport(actorSystem: ActorSystem,
+ class NoServiceDiscovery extends ServiceDiscovery with SavingUsedServiceDiscovery {
+
+ def discover[T <: Service](serviceName: Name[Service]): T =
+ throw new IllegalArgumentException(s"Service with name $serviceName is unknown")
+ }
+
+ trait SavingUsedServiceDiscovery {
+
+ private val usedServices = new scala.collection.mutable.HashSet[String]()
+
+ def saveServiceUsage(serviceName: Name[Service]): Unit = usedServices.synchronized {
+ usedServices += serviceName.value
+ }
+
+ def getUsedServices: Set[String] = usedServices.synchronized { usedServices.toSet }
+ }
+
+ class HttpRestServiceTransport(applicationName: Name[App],
+ applicationVersion: String,
+ actorSystem: ActorSystem,
executionContext: ExecutionContext,
log: Logger,
time: TimeProvider)
@@ -260,19 +381,34 @@ package rest {
protected implicit val materializer = ActorMaterializer()(actorSystem)
protected implicit val execution = executionContext
+ private val client = Http()(actorSystem)
+
+ private val clientConnectionSettings: ClientConnectionSettings =
+ ClientConnectionSettings(actorSystem).withUserAgentHeader(
+ Option(`User-Agent`(applicationName.value + "/" + applicationVersion)))
+
+ private val connectionPoolSettings: ConnectionPoolSettings = ConnectionPoolSettings(actorSystem)
+ .withConnectionSettings(clientConnectionSettings)
+
def sendRequestGetResponse(context: ServiceRequestContext)(requestStub: HttpRequest): Future[HttpResponse] = {
val requestTime = time.currentTime()
val request = requestStub
- .withHeaders(RawHeader(ContextHeaders.TrackingIdHeader, context.trackingId))
- .withHeaders(context.contextHeaders.toSeq.map { h =>
- RawHeader(h._1, h._2): HttpHeader
+ .withHeaders(context.contextHeaders.toSeq.map {
+ case (ContextHeaders.TrackingIdHeader, headerValue) =>
+ RawHeader(ContextHeaders.TrackingIdHeader, context.trackingId)
+ case (ContextHeaders.StacktraceHeader, headerValue) =>
+ RawHeader(ContextHeaders.StacktraceHeader,
+ Option(MDC.get("stack"))
+ .orElse(context.contextHeaders.get(ContextHeaders.StacktraceHeader))
+ .getOrElse(""))
+ case (header, headerValue) => RawHeader(header, headerValue)
}: _*)
log.info(s"Sending request to ${request.method} ${request.uri}")
- val response = Http()(actorSystem).singleRequest(request)(materializer)
+ val response = client.singleRequest(request, settings = connectionPoolSettings)(materializer)
response.onComplete {
case Success(r) =>