From 2c524be93052fc6d57359a00fd60d957099885c6 Mon Sep 17 00:00:00 2001 From: Zach Smith Date: Wed, 27 Sep 2017 22:20:53 -0700 Subject: Refactor rest package into separate files --- src/main/scala/xyz/driver/core/rest.scala | 586 --------------------- .../rest/AuthorizedServiceRequestContext.scala | 27 + .../scala/xyz/driver/core/rest/HttpClient.scala | 9 + .../core/rest/HttpRestServiceTransport.scala | 73 +++ .../scala/xyz/driver/core/rest/Implicits.scala | 17 + .../xyz/driver/core/rest/NoServiceDiscovery.scala | 9 + .../scala/xyz/driver/core/rest/Pagination.scala | 3 + .../xyz/driver/core/rest/PooledHttpClient.scala | 67 +++ .../scala/xyz/driver/core/rest/RestService.scala | 54 ++ .../core/rest/SavingUsedServiceDiscovery.scala | 14 + src/main/scala/xyz/driver/core/rest/Service.scala | 3 + .../xyz/driver/core/rest/ServiceDiscovery.scala | 8 + .../driver/core/rest/ServiceRequestContext.scala | 39 ++ .../xyz/driver/core/rest/ServiceTransport.scala | 13 + .../driver/core/rest/SingleRequestHttpClient.scala | 29 + src/main/scala/xyz/driver/core/rest/Swagger.scala | 52 ++ .../core/rest/auth/AlwaysAllowAuthorization.scala | 14 + .../xyz/driver/core/rest/auth/AuthProvider.scala | 64 +++ .../xyz/driver/core/rest/auth/Authorization.scala | 11 + .../core/rest/auth/AuthorizationResult.scala | 22 + .../core/rest/auth/CachedTokenAuthorization.scala | 55 ++ .../core/rest/auth/ChainedAuthorization.scala | 27 + src/main/scala/xyz/driver/core/rest/package.scala | 92 ++++ 23 files changed, 702 insertions(+), 586 deletions(-) delete mode 100644 src/main/scala/xyz/driver/core/rest.scala create mode 100644 src/main/scala/xyz/driver/core/rest/AuthorizedServiceRequestContext.scala create mode 100644 src/main/scala/xyz/driver/core/rest/HttpClient.scala create mode 100644 src/main/scala/xyz/driver/core/rest/HttpRestServiceTransport.scala create mode 100644 src/main/scala/xyz/driver/core/rest/Implicits.scala create mode 100644 src/main/scala/xyz/driver/core/rest/NoServiceDiscovery.scala create mode 100644 src/main/scala/xyz/driver/core/rest/Pagination.scala create mode 100644 src/main/scala/xyz/driver/core/rest/PooledHttpClient.scala create mode 100644 src/main/scala/xyz/driver/core/rest/RestService.scala create mode 100644 src/main/scala/xyz/driver/core/rest/SavingUsedServiceDiscovery.scala create mode 100644 src/main/scala/xyz/driver/core/rest/Service.scala create mode 100644 src/main/scala/xyz/driver/core/rest/ServiceDiscovery.scala create mode 100644 src/main/scala/xyz/driver/core/rest/ServiceRequestContext.scala create mode 100644 src/main/scala/xyz/driver/core/rest/ServiceTransport.scala create mode 100644 src/main/scala/xyz/driver/core/rest/SingleRequestHttpClient.scala create mode 100644 src/main/scala/xyz/driver/core/rest/Swagger.scala create mode 100644 src/main/scala/xyz/driver/core/rest/auth/AlwaysAllowAuthorization.scala create mode 100644 src/main/scala/xyz/driver/core/rest/auth/AuthProvider.scala create mode 100644 src/main/scala/xyz/driver/core/rest/auth/Authorization.scala create mode 100644 src/main/scala/xyz/driver/core/rest/auth/AuthorizationResult.scala create mode 100644 src/main/scala/xyz/driver/core/rest/auth/CachedTokenAuthorization.scala create mode 100644 src/main/scala/xyz/driver/core/rest/auth/ChainedAuthorization.scala create mode 100644 src/main/scala/xyz/driver/core/rest/package.scala (limited to 'src/main/scala/xyz/driver/core') diff --git a/src/main/scala/xyz/driver/core/rest.scala b/src/main/scala/xyz/driver/core/rest.scala deleted file mode 100644 index ee564de..0000000 --- a/src/main/scala/xyz/driver/core/rest.scala +++ /dev/null @@ -1,586 +0,0 @@ -package xyz.driver.core.rest - -import java.nio.file.{Files, Path} -import java.security.spec.X509EncodedKeySpec -import java.security.{KeyFactory, PublicKey} - -import akka.actor.ActorSystem -import akka.http.scaladsl.Http -import akka.http.scaladsl.marshalling.{ToEntityMarshaller, ToResponseMarshallable} -import akka.http.scaladsl.model._ -import akka.http.scaladsl.model.headers.{HttpChallenges, RawHeader, `User-Agent`} -import akka.http.scaladsl.server.AuthenticationFailedRejection.CredentialsRejected -import akka.http.scaladsl.server.Route -import akka.http.scaladsl.settings.{ClientConnectionSettings, ConnectionPoolSettings} -import akka.http.scaladsl.unmarshalling.{Unmarshal, Unmarshaller} -import akka.stream._ -import akka.stream.scaladsl.{Flow, Keep, Sink, Source} -import akka.util.ByteString -import com.github.swagger.akka.model._ -import com.github.swagger.akka.{HasActorSystem, SwaggerHttpService} -import com.typesafe.config.Config -import com.typesafe.scalalogging.Logger -import io.swagger.models.Scheme -import org.slf4j.MDC -import pdi.jwt.{Jwt, JwtAlgorithm} -import xyz.driver.core.auth._ -import xyz.driver.core.time.provider.TimeProvider -import xyz.driver.core.{Name, generators} - -import scala.concurrent.duration._ -import scala.concurrent.{ExecutionContext, Future, Promise} -import scala.util.{Failure, Success} -import scalaz.Scalaz.{futureInstance, intInstance, listInstance, mapEqual, mapMonoid, stringInstance} -import scalaz.syntax.equal._ -import scalaz.syntax.semigroup._ -import scalaz.syntax.std.boolean._ -import scalaz.syntax.traverse._ -import scalaz.{Functor, ListT, OptionT, Semigroup} - -object `package` { - import akka.http.scaladsl.server._ - import Directives._ - - def serviceContext: Directive1[ServiceRequestContext] = extract(ctx => extractServiceContext(ctx.request)) - - def extractServiceContext(request: HttpRequest): ServiceRequestContext = - new ServiceRequestContext(extractTrackingId(request), extractContextHeaders(request)) - - def extractTrackingId(request: HttpRequest): String = { - request.headers - .find(_.name == ContextHeaders.TrackingIdHeader) - .fold(java.util.UUID.randomUUID.toString)(_.value()) - } - - def extractStacktrace(request: HttpRequest): Array[String] = - request.headers.find(_.name == ContextHeaders.StacktraceHeader).fold("")(_.value()).split("->") - - def extractContextHeaders(request: HttpRequest): Map[String, String] = { - request.headers.filter { h => - h.name === ContextHeaders.AuthenticationTokenHeader || h.name === ContextHeaders.TrackingIdHeader || - h.name === ContextHeaders.PermissionsTokenHeader || h.name === ContextHeaders.StacktraceHeader || - h.name === ContextHeaders.TraceHeaderName || h.name === ContextHeaders.SpanHeaderName - } map { header => - if (header.name === ContextHeaders.AuthenticationTokenHeader) { - header.name -> header.value.stripPrefix(ContextHeaders.AuthenticationHeaderPrefix).trim - } else { - header.name -> header.value - } - } toMap - } - - private[rest] def escapeScriptTags(byteString: ByteString): ByteString = { - @annotation.tailrec - def dirtyIndices(from: Int, descIndices: List[Int]): List[Int] = { - val index = byteString.indexOf('/', from) - if (index === -1) descIndices.reverse - else { - val (init, tail) = byteString.splitAt(index) - if ((init endsWith "<") && (tail startsWith "/sc")) { - dirtyIndices(index + 1, index :: descIndices) - } else { - dirtyIndices(index + 1, descIndices) - } - } - } - - val indices = dirtyIndices(0, Nil) - - indices.headOption.fold(byteString) { head => - val builder = ByteString.newBuilder - builder ++= byteString.take(head) - - (indices :+ byteString.length).sliding(2).foreach { - case Seq(start, end) => - builder += ' ' - builder ++= byteString.slice(start, end) - case Seq(_) => // Should not match; sliding on at least 2 elements - assert(indices.nonEmpty, s"Indices should have been nonEmpty: $indices") - } - builder.result - } - } - - val sanitizeRequestEntity: Directive0 = { - mapRequest(request => request.mapEntity(entity => entity.transformDataBytes(Flow.fromFunction(escapeScriptTags)))) - } -} - -object Implicits { - implicit class OptionTRestAdditions[T](optionT: OptionT[Future, T]) { - def responseOrNotFound(successCode: StatusCodes.Success = StatusCodes.OK)( - implicit F: Functor[Future], - em: ToEntityMarshaller[T]): Future[ToResponseMarshallable] = { - optionT.fold[ToResponseMarshallable](successCode -> _, StatusCodes.NotFound -> None) - } - } -} - -class ServiceRequestContext(val trackingId: String = generators.nextUuid().toString, - val contextHeaders: Map[String, String] = Map.empty[String, String]) { - def authToken: Option[AuthToken] = - contextHeaders.get(AuthProvider.AuthenticationTokenHeader).map(AuthToken.apply) - - def permissionsToken: Option[PermissionsToken] = - contextHeaders.get(AuthProvider.PermissionsTokenHeader).map(PermissionsToken.apply) - - def withAuthToken(authToken: AuthToken): ServiceRequestContext = - new ServiceRequestContext( - trackingId, - contextHeaders.updated(AuthProvider.AuthenticationTokenHeader, authToken.value) - ) - - def withAuthenticatedUser[U <: User](authToken: AuthToken, user: U): AuthorizedServiceRequestContext[U] = - new AuthorizedServiceRequestContext( - trackingId, - contextHeaders.updated(AuthProvider.AuthenticationTokenHeader, authToken.value), - user - ) - - override def hashCode(): Int = - Seq[Any](trackingId, contextHeaders).foldLeft(31)((result, obj) => 31 * result + obj.hashCode()) - - override def equals(obj: Any): Boolean = obj match { - case ctx: ServiceRequestContext => trackingId === ctx.trackingId && contextHeaders === ctx.contextHeaders - case _ => false - } - - override def toString: String = s"ServiceRequestContext($trackingId, $contextHeaders)" -} - -class AuthorizedServiceRequestContext[U <: User](override val trackingId: String = generators.nextUuid().toString, - override val contextHeaders: Map[String, String] = - Map.empty[String, String], - val authenticatedUser: U) - extends ServiceRequestContext { - - def withPermissionsToken(permissionsToken: PermissionsToken): AuthorizedServiceRequestContext[U] = - new AuthorizedServiceRequestContext[U]( - trackingId, - contextHeaders.updated(AuthProvider.PermissionsTokenHeader, permissionsToken.value), - authenticatedUser) - - override def hashCode(): Int = 31 * super.hashCode() + authenticatedUser.hashCode() - - override def equals(obj: Any): Boolean = obj match { - case ctx: AuthorizedServiceRequestContext[U] => super.equals(ctx) && ctx.authenticatedUser == authenticatedUser - case _ => false - } - - override def toString: String = - s"AuthorizedServiceRequestContext($trackingId, $contextHeaders, $authenticatedUser)" -} - -object ContextHeaders { - import xyz.driver.tracing.TracingDirectives - val AuthenticationTokenHeader = "Authorization" - val PermissionsTokenHeader = "Permissions" - val AuthenticationHeaderPrefix = "Bearer" - val TrackingIdHeader = "X-Trace" - val StacktraceHeader = "X-Stacktrace" - val TraceHeaderName = TracingDirectives.TraceHeaderName - val SpanHeaderName = TracingDirectives.SpanHeaderName -} - -object AuthProvider { - val AuthenticationTokenHeader = ContextHeaders.AuthenticationTokenHeader - val PermissionsTokenHeader = ContextHeaders.PermissionsTokenHeader - val SetAuthenticationTokenHeader = "set-authorization" - val SetPermissionsTokenHeader = "set-permissions" -} - -final case class Pagination(pageSize: Int, pageNumber: Int) - -final case class AuthorizationResult(authorized: Map[Permission, Boolean], token: Option[PermissionsToken]) -object AuthorizationResult { - val unauthorized: AuthorizationResult = AuthorizationResult(authorized = Map.empty, None) - - implicit val authorizationSemigroup: Semigroup[AuthorizationResult] = new Semigroup[AuthorizationResult] { - private implicit val authorizedBooleanSemigroup = Semigroup.instance[Boolean](_ || _) - private implicit val permissionsTokenSemigroup = - Semigroup.instance[Option[PermissionsToken]]((a, b) => b.orElse(a)) - - override def append(a: AuthorizationResult, b: => AuthorizationResult): AuthorizationResult = { - AuthorizationResult(a.authorized |+| b.authorized, a.token |+| b.token) - } - } -} - -trait Authorization[U <: User] { - def userHasPermissions(user: U, permissions: Seq[Permission])( - implicit ctx: ServiceRequestContext): Future[AuthorizationResult] -} - -class AlwaysAllowAuthorization[U <: User] extends Authorization[U] { - override def userHasPermissions(user: U, permissions: Seq[Permission])( - implicit ctx: ServiceRequestContext): Future[AuthorizationResult] = { - val permissionsMap = permissions.map(_ -> true).toMap - Future.successful(AuthorizationResult(authorized = permissionsMap, ctx.permissionsToken)) - } -} - -class CachedTokenAuthorization[U <: User](publicKey: => PublicKey, issuer: String) extends Authorization[U] { - override def userHasPermissions(user: U, permissions: Seq[Permission])( - implicit ctx: ServiceRequestContext): Future[AuthorizationResult] = { - import spray.json._ - - def extractPermissionsFromTokenJSON(tokenObject: JsObject): Option[Map[String, Boolean]] = - tokenObject.fields.get("permissions").collect { - case JsObject(fields) => - fields.collect { - case (key, JsBoolean(value)) => key -> value - } - } - - val result = for { - token <- ctx.permissionsToken - jwt <- Jwt.decode(token.value, publicKey, Seq(JwtAlgorithm.RS256)).toOption - jwtJson = jwt.parseJson.asJsObject - - // Ensure jwt is for the currently authenticated user and the correct issuer, otherwise return None - _ <- jwtJson.fields.get("sub").contains(JsString(user.id.value)).option(()) - _ <- jwtJson.fields.get("iss").contains(JsString(issuer)).option(()) - - permissionsMap <- extractPermissionsFromTokenJSON(jwtJson) - - authorized = permissions.map(p => p -> permissionsMap.getOrElse(p.toString, false)).toMap - } yield AuthorizationResult(authorized, Some(token)) - - Future.successful(result.getOrElse(AuthorizationResult.unauthorized)) - } -} - -object CachedTokenAuthorization { - def apply[U <: User](publicKeyFile: Path, issuer: String): CachedTokenAuthorization[U] = { - lazy val publicKey: PublicKey = { - val publicKeyBase64Encoded = new String(Files.readAllBytes(publicKeyFile)).trim - val publicKeyBase64Decoded = java.util.Base64.getDecoder.decode(publicKeyBase64Encoded) - val spec = new X509EncodedKeySpec(publicKeyBase64Decoded) - KeyFactory.getInstance("RSA").generatePublic(spec) - } - new CachedTokenAuthorization[U](publicKey, issuer) - } -} - -class ChainedAuthorization[U <: User](authorizations: Authorization[U]*)(implicit execution: ExecutionContext) - extends Authorization[U] { - - override def userHasPermissions(user: U, permissions: Seq[Permission])( - implicit ctx: ServiceRequestContext): Future[AuthorizationResult] = { - def allAuthorized(permissionsMap: Map[Permission, Boolean]): Boolean = - permissions.forall(permissionsMap.getOrElse(_, false)) - - authorizations.toList.foldLeftM[Future, AuthorizationResult](AuthorizationResult.unauthorized) { - (authResult, authorization) => - if (allAuthorized(authResult.authorized)) Future.successful(authResult) - else { - authorization.userHasPermissions(user, permissions).map(authResult |+| _) - } - } - } -} - -abstract class AuthProvider[U <: User](val authorization: Authorization[U], log: Logger)( - implicit execution: ExecutionContext) { - - import akka.http.scaladsl.server._ - import Directives._ - - /** - * Specific implementation on how to extract user from request context, - * can either need to do a network call to auth server or extract everything from self-contained token - * - * @param ctx set of request values which can be relevant to authenticate user - * @return authenticated user - */ - def authenticatedUser(implicit ctx: ServiceRequestContext): OptionT[Future, U] - - /** - * Verifies if request is authenticated and authorized to have `permissions` - */ - def authorize(permissions: Permission*): Directive1[AuthorizedServiceRequestContext[U]] = { - serviceContext flatMap { ctx => - onComplete { - (for { - authToken <- OptionT.optionT(Future.successful(ctx.authToken)) - user <- authenticatedUser(ctx) - authCtx = ctx.withAuthenticatedUser(authToken, user) - authorizationResult <- authorization.userHasPermissions(user, permissions)(authCtx).toOptionT - - cachedPermissionsAuthCtx = authorizationResult.token.fold(authCtx)(authCtx.withPermissionsToken) - allAuthorized = permissions.forall(authorizationResult.authorized.getOrElse(_, false)) - } yield (cachedPermissionsAuthCtx, allAuthorized)).run - } flatMap { - case Success(Some((authCtx, true))) => provide(authCtx) - case Success(Some((authCtx, false))) => - val challenge = - HttpChallenges.basic(s"User does not have the required permissions: ${permissions.mkString(", ")}") - log.warn( - s"User ${authCtx.authenticatedUser} does not have the required permissions: ${permissions.mkString(", ")}") - reject(AuthenticationFailedRejection(CredentialsRejected, challenge)) - case Success(None) => - log.warn( - s"Wasn't able to find authenticated user for the token provided to verify ${permissions.mkString(", ")}") - reject(ValidationRejection(s"Wasn't able to find authenticated user for the token provided")) - case Failure(t) => - log.warn(s"Wasn't able to verify token for authenticated user to verify ${permissions.mkString(", ")}", t) - reject(ValidationRejection(s"Wasn't able to verify token for authenticated user", Some(t))) - } - } - } -} - -trait Service - -trait RestService extends Service { - - import akka.http.scaladsl.marshallers.sprayjson.SprayJsonSupport._ - import spray.json._ - - protected implicit val exec: ExecutionContext - protected implicit val materializer: Materializer - - implicit class ResponseEntityFoldable(entity: Unmarshal[ResponseEntity]) { - def fold[T](default: => T)(implicit um: Unmarshaller[ResponseEntity, T]): Future[T] = - if (entity.value.isKnownEmpty()) Future.successful[T](default) else entity.to[T] - } - - protected def unitResponse(request: Future[Unmarshal[ResponseEntity]]): OptionT[Future, Unit] = - OptionT[Future, Unit](request.flatMap(_.to[String]).map(_ => Option(()))) - - protected def optionalResponse[T](request: Future[Unmarshal[ResponseEntity]])( - implicit um: Unmarshaller[ResponseEntity, Option[T]]): OptionT[Future, T] = - OptionT[Future, T](request.flatMap(_.fold(Option.empty[T]))) - - protected def listResponse[T](request: Future[Unmarshal[ResponseEntity]])( - implicit um: Unmarshaller[ResponseEntity, List[T]]): ListT[Future, T] = - ListT[Future, T](request.flatMap(_.fold(List.empty[T]))) - - protected def jsonEntity(json: JsValue): RequestEntity = - HttpEntity(ContentTypes.`application/json`, json.compactPrint) - - protected def get(baseUri: Uri, path: String, query: Seq[(String, String)] = Seq.empty) = - HttpRequest(HttpMethods.GET, endpointUri(baseUri, path, query)) - - protected def post(baseUri: Uri, path: String, httpEntity: RequestEntity) = - HttpRequest(HttpMethods.POST, endpointUri(baseUri, path), entity = httpEntity) - - protected def postJson(baseUri: Uri, path: String, json: JsValue) = - HttpRequest(HttpMethods.POST, endpointUri(baseUri, path), entity = jsonEntity(json)) - - protected def delete(baseUri: Uri, path: String) = - HttpRequest(HttpMethods.DELETE, endpointUri(baseUri, path)) - - protected def endpointUri(baseUri: Uri, path: String): Uri = - baseUri.withPath(Uri.Path(path)) - - protected def endpointUri(baseUri: Uri, path: String, query: Seq[(String, String)]): Uri = - baseUri.withPath(Uri.Path(path)).withQuery(Uri.Query(query: _*)) -} - -trait ServiceTransport { - - def sendRequestGetResponse(context: ServiceRequestContext)(requestStub: HttpRequest): Future[HttpResponse] - - def sendRequest(context: ServiceRequestContext)(requestStub: HttpRequest): Future[Unmarshal[ResponseEntity]] -} - -trait ServiceDiscovery { - - def discover[T <: Service](serviceName: Name[Service]): T -} - -class NoServiceDiscovery extends ServiceDiscovery with SavingUsedServiceDiscovery { - - def discover[T <: Service](serviceName: Name[Service]): T = - throw new IllegalArgumentException(s"Service with name $serviceName is unknown") -} - -trait SavingUsedServiceDiscovery { - - private val usedServices = new scala.collection.mutable.HashSet[String]() - - def saveServiceUsage(serviceName: Name[Service]): Unit = usedServices.synchronized { - usedServices += serviceName.value - } - - def getUsedServices: Set[String] = usedServices.synchronized { usedServices.toSet } -} - -class HttpRestServiceTransport(applicationName: Name[App], - applicationVersion: String, - actorSystem: ActorSystem, - executionContext: ExecutionContext, - log: Logger, - time: TimeProvider) - extends ServiceTransport { - - protected implicit val execution: ExecutionContext = executionContext - - protected val httpClient: HttpClient = new SingleRequestHttpClient(applicationName, applicationVersion, actorSystem) - - def sendRequestGetResponse(context: ServiceRequestContext)(requestStub: HttpRequest): Future[HttpResponse] = { - - val requestTime = time.currentTime() - - val request = requestStub - .withHeaders(context.contextHeaders.toSeq.map { - case (ContextHeaders.TrackingIdHeader, _) => - RawHeader(ContextHeaders.TrackingIdHeader, context.trackingId) - case (ContextHeaders.StacktraceHeader, _) => - RawHeader(ContextHeaders.StacktraceHeader, - Option(MDC.get("stack")) - .orElse(context.contextHeaders.get(ContextHeaders.StacktraceHeader)) - .getOrElse("")) - case (header, headerValue) => RawHeader(header, headerValue) - }: _*) - - log.info(s"Sending request to ${request.method} ${request.uri}") - - val response = httpClient.makeRequest(request) - - response.onComplete { - case Success(r) => - val responseLatency = requestTime.durationTo(time.currentTime()) - log.info(s"Response from ${request.uri} to request $requestStub is successful in $responseLatency ms: $r") - - case Failure(t: Throwable) => - val responseLatency = requestTime.durationTo(time.currentTime()) - log.info(s"Failed to receive response from ${request.method} ${request.uri} in $responseLatency ms", t) - log.warn(s"Failed to receive response from ${request.method} ${request.uri} in $responseLatency ms", t) - }(executionContext) - - response - } - - def sendRequest(context: ServiceRequestContext)(requestStub: HttpRequest): Future[Unmarshal[ResponseEntity]] = { - - sendRequestGetResponse(context)(requestStub) map { response => - if (response.status == StatusCodes.NotFound) { - Unmarshal(HttpEntity.Empty: ResponseEntity) - } else if (response.status.isFailure()) { - throw new Exception(s"Http status is failure ${response.status} for ${requestStub.method} ${requestStub.uri}") - } else { - Unmarshal(response.entity) - } - } - } -} - -trait HttpClient { - def makeRequest(request: HttpRequest): Future[HttpResponse] -} - -class SingleRequestHttpClient(applicationName: Name[App], applicationVersion: String, actorSystem: ActorSystem) - extends HttpClient { - - protected implicit val materializer: ActorMaterializer = ActorMaterializer()(actorSystem) - private val client = Http()(actorSystem) - - private val clientConnectionSettings: ClientConnectionSettings = - ClientConnectionSettings(actorSystem).withUserAgentHeader( - Option(`User-Agent`(applicationName.value + "/" + applicationVersion))) - - private val connectionPoolSettings: ConnectionPoolSettings = ConnectionPoolSettings(actorSystem) - .withConnectionSettings(clientConnectionSettings) - - def makeRequest(request: HttpRequest): Future[HttpResponse] = { - client.singleRequest(request, settings = connectionPoolSettings)(materializer) - } -} - -class PooledHttpClient( - baseUri: Uri, - applicationName: Name[App], - applicationVersion: String, - requestRateLimit: Int = 64, - requestQueueSize: Int = 1024)(implicit actorSystem: ActorSystem, executionContext: ExecutionContext) - extends HttpClient { - - private val host = baseUri.authority.host.toString() - private val port = baseUri.effectivePort - private val scheme = baseUri.scheme - - protected implicit val materializer: ActorMaterializer = ActorMaterializer()(actorSystem) - - private val clientConnectionSettings: ClientConnectionSettings = - ClientConnectionSettings(actorSystem).withUserAgentHeader( - Option(`User-Agent`(applicationName.value + "/" + applicationVersion))) - - private val connectionPoolSettings: ConnectionPoolSettings = ConnectionPoolSettings(actorSystem) - .withConnectionSettings(clientConnectionSettings) - - private val pool = if (scheme.equalsIgnoreCase("https")) { - Http().cachedHostConnectionPoolHttps[Promise[HttpResponse]](host, port, settings = connectionPoolSettings) - } else { - Http().cachedHostConnectionPool[Promise[HttpResponse]](host, port, settings = connectionPoolSettings) - } - - private val queue = Source - .queue[(HttpRequest, Promise[HttpResponse])](requestQueueSize, OverflowStrategy.dropNew) - .via(pool) - .throttle(requestRateLimit, 1.second, maximumBurst = requestRateLimit, ThrottleMode.shaping) - .toMat(Sink.foreach({ - case ((Success(resp), p)) => p.success(resp) - case ((Failure(e), p)) => p.failure(e) - }))(Keep.left) - .run - - def makeRequest(request: HttpRequest): Future[HttpResponse] = { - val responsePromise = Promise[HttpResponse]() - - queue.offer(request -> responsePromise).flatMap { - case QueueOfferResult.Enqueued => - responsePromise.future - case QueueOfferResult.Dropped => - Future.failed(new Exception(s"Request queue to the host $host is overflown")) - case QueueOfferResult.Failure(ex) => - Future.failed(ex) - case QueueOfferResult.QueueClosed => - Future.failed(new Exception("Queue was closed (pool shut down) while running the request")) - } - } -} - -import scala.reflect.runtime.universe._ - -class Swagger(override val host: String, - override val scheme: Scheme, - version: String, - override val actorSystem: ActorSystem, - override val apiTypes: Seq[Type], - val config: Config) - extends SwaggerHttpService with HasActorSystem { - - val materializer: ActorMaterializer = ActorMaterializer()(actorSystem) - - override val basePath: String = config.getString("swagger.basePath") - override val apiDocsPath: String = config.getString("swagger.docsPath") - - override val info = Info( - config.getString("swagger.apiInfo.description"), - version, - config.getString("swagger.apiInfo.title"), - config.getString("swagger.apiInfo.termsOfServiceUrl"), - contact = Some( - Contact( - config.getString("swagger.apiInfo.contact.name"), - config.getString("swagger.apiInfo.contact.url"), - config.getString("swagger.apiInfo.contact.email") - )), - license = Some( - License( - config.getString("swagger.apiInfo.license"), - config.getString("swagger.apiInfo.licenseUrl") - )), - vendorExtensions = Map.empty[String, AnyRef] - ) - - def swaggerUI: Route = get { - pathPrefix("") { - pathEndOrSingleSlash { - getFromResource("swagger-ui/index.html") - } - } ~ getFromResourceDirectory("swagger-ui") - } -} diff --git a/src/main/scala/xyz/driver/core/rest/AuthorizedServiceRequestContext.scala b/src/main/scala/xyz/driver/core/rest/AuthorizedServiceRequestContext.scala new file mode 100644 index 0000000..1cf62c9 --- /dev/null +++ b/src/main/scala/xyz/driver/core/rest/AuthorizedServiceRequestContext.scala @@ -0,0 +1,27 @@ +package xyz.driver.core.rest + +import xyz.driver.core.auth.{PermissionsToken, User} +import xyz.driver.core.generators + +class AuthorizedServiceRequestContext[U <: User](override val trackingId: String = generators.nextUuid().toString, + override val contextHeaders: Map[String, String] = + Map.empty[String, String], + val authenticatedUser: U) + extends ServiceRequestContext { + + def withPermissionsToken(permissionsToken: PermissionsToken): AuthorizedServiceRequestContext[U] = + new AuthorizedServiceRequestContext[U]( + trackingId, + contextHeaders.updated(AuthProvider.PermissionsTokenHeader, permissionsToken.value), + authenticatedUser) + + override def hashCode(): Int = 31 * super.hashCode() + authenticatedUser.hashCode() + + override def equals(obj: Any): Boolean = obj match { + case ctx: AuthorizedServiceRequestContext[U] => super.equals(ctx) && ctx.authenticatedUser == authenticatedUser + case _ => false + } + + override def toString: String = + s"AuthorizedServiceRequestContext($trackingId, $contextHeaders, $authenticatedUser)" +} diff --git a/src/main/scala/xyz/driver/core/rest/HttpClient.scala b/src/main/scala/xyz/driver/core/rest/HttpClient.scala new file mode 100644 index 0000000..6f6fda0 --- /dev/null +++ b/src/main/scala/xyz/driver/core/rest/HttpClient.scala @@ -0,0 +1,9 @@ +package xyz.driver.core.rest + +import akka.http.scaladsl.model.{HttpRequest, HttpResponse} + +import scala.concurrent.Future + +trait HttpClient { + def makeRequest(request: HttpRequest): Future[HttpResponse] +} diff --git a/src/main/scala/xyz/driver/core/rest/HttpRestServiceTransport.scala b/src/main/scala/xyz/driver/core/rest/HttpRestServiceTransport.scala new file mode 100644 index 0000000..1e95811 --- /dev/null +++ b/src/main/scala/xyz/driver/core/rest/HttpRestServiceTransport.scala @@ -0,0 +1,73 @@ +package xyz.driver.core.rest + +import akka.actor.ActorSystem +import akka.http.scaladsl.model._ +import akka.http.scaladsl.model.headers.RawHeader +import akka.http.scaladsl.unmarshalling.Unmarshal +import com.typesafe.scalalogging.Logger +import org.slf4j.MDC +import xyz.driver.core.Name +import xyz.driver.core.time.provider.TimeProvider + +import scala.concurrent.{ExecutionContext, Future} +import scala.util.{Failure, Success} + +class HttpRestServiceTransport(applicationName: Name[App], + applicationVersion: String, + actorSystem: ActorSystem, + executionContext: ExecutionContext, + log: Logger, + time: TimeProvider) + extends ServiceTransport { + + protected implicit val execution: ExecutionContext = executionContext + + protected val httpClient: HttpClient = new SingleRequestHttpClient(applicationName, applicationVersion, actorSystem) + + def sendRequestGetResponse(context: ServiceRequestContext)(requestStub: HttpRequest): Future[HttpResponse] = { + + val requestTime = time.currentTime() + + val request = requestStub + .withHeaders(context.contextHeaders.toSeq.map { + case (ContextHeaders.TrackingIdHeader, _) => + RawHeader(ContextHeaders.TrackingIdHeader, context.trackingId) + case (ContextHeaders.StacktraceHeader, _) => + RawHeader(ContextHeaders.StacktraceHeader, + Option(MDC.get("stack")) + .orElse(context.contextHeaders.get(ContextHeaders.StacktraceHeader)) + .getOrElse("")) + case (header, headerValue) => RawHeader(header, headerValue) + }: _*) + + log.info(s"Sending request to ${request.method} ${request.uri}") + + val response = httpClient.makeRequest(request) + + response.onComplete { + case Success(r) => + val responseLatency = requestTime.durationTo(time.currentTime()) + log.info(s"Response from ${request.uri} to request $requestStub is successful in $responseLatency ms: $r") + + case Failure(t: Throwable) => + val responseLatency = requestTime.durationTo(time.currentTime()) + log.info(s"Failed to receive response from ${request.method} ${request.uri} in $responseLatency ms", t) + log.warn(s"Failed to receive response from ${request.method} ${request.uri} in $responseLatency ms", t) + }(executionContext) + + response + } + + def sendRequest(context: ServiceRequestContext)(requestStub: HttpRequest): Future[Unmarshal[ResponseEntity]] = { + + sendRequestGetResponse(context)(requestStub) map { response => + if (response.status == StatusCodes.NotFound) { + Unmarshal(HttpEntity.Empty: ResponseEntity) + } else if (response.status.isFailure()) { + throw new Exception(s"Http status is failure ${response.status} for ${requestStub.method} ${requestStub.uri}") + } else { + Unmarshal(response.entity) + } + } + } +} diff --git a/src/main/scala/xyz/driver/core/rest/Implicits.scala b/src/main/scala/xyz/driver/core/rest/Implicits.scala new file mode 100644 index 0000000..8b499dd --- /dev/null +++ b/src/main/scala/xyz/driver/core/rest/Implicits.scala @@ -0,0 +1,17 @@ +package xyz.driver.core.rest + +import akka.http.scaladsl.marshalling.{ToEntityMarshaller, ToResponseMarshallable} +import akka.http.scaladsl.model.StatusCodes + +import scala.concurrent.Future +import scalaz.{Functor, OptionT} + +object Implicits { + implicit class OptionTRestAdditions[T](optionT: OptionT[Future, T]) { + def responseOrNotFound(successCode: StatusCodes.Success = StatusCodes.OK)( + implicit F: Functor[Future], + em: ToEntityMarshaller[T]): Future[ToResponseMarshallable] = { + optionT.fold[ToResponseMarshallable](successCode -> _, StatusCodes.NotFound -> None) + } + } +} diff --git a/src/main/scala/xyz/driver/core/rest/NoServiceDiscovery.scala b/src/main/scala/xyz/driver/core/rest/NoServiceDiscovery.scala new file mode 100644 index 0000000..9d2febd --- /dev/null +++ b/src/main/scala/xyz/driver/core/rest/NoServiceDiscovery.scala @@ -0,0 +1,9 @@ +package xyz.driver.core.rest + +import xyz.driver.core.Name + +class NoServiceDiscovery extends ServiceDiscovery with SavingUsedServiceDiscovery { + + def discover[T <: Service](serviceName: Name[Service]): T = + throw new IllegalArgumentException(s"Service with name $serviceName is unknown") +} diff --git a/src/main/scala/xyz/driver/core/rest/Pagination.scala b/src/main/scala/xyz/driver/core/rest/Pagination.scala new file mode 100644 index 0000000..f97660f --- /dev/null +++ b/src/main/scala/xyz/driver/core/rest/Pagination.scala @@ -0,0 +1,3 @@ +package xyz.driver.core.rest + +final case class Pagination(pageSize: Int, pageNumber: Int) diff --git a/src/main/scala/xyz/driver/core/rest/PooledHttpClient.scala b/src/main/scala/xyz/driver/core/rest/PooledHttpClient.scala new file mode 100644 index 0000000..2c9dcac --- /dev/null +++ b/src/main/scala/xyz/driver/core/rest/PooledHttpClient.scala @@ -0,0 +1,67 @@ +package xyz.driver.core.rest + +import akka.actor.ActorSystem +import akka.http.scaladsl.Http +import akka.http.scaladsl.model.headers.`User-Agent` +import akka.http.scaladsl.model.{HttpRequest, HttpResponse, Uri} +import akka.http.scaladsl.settings.{ClientConnectionSettings, ConnectionPoolSettings} +import akka.stream.scaladsl.{Keep, Sink, Source} +import akka.stream.{ActorMaterializer, OverflowStrategy, QueueOfferResult, ThrottleMode} +import xyz.driver.core.Name + +import scala.concurrent.{ExecutionContext, Future, Promise} +import scala.concurrent.duration._ +import scala.util.{Failure, Success} + +class PooledHttpClient( + baseUri: Uri, + applicationName: Name[App], + applicationVersion: String, + requestRateLimit: Int = 64, + requestQueueSize: Int = 1024)(implicit actorSystem: ActorSystem, executionContext: ExecutionContext) + extends HttpClient { + + private val host = baseUri.authority.host.toString() + private val port = baseUri.effectivePort + private val scheme = baseUri.scheme + + protected implicit val materializer: ActorMaterializer = ActorMaterializer()(actorSystem) + + private val clientConnectionSettings: ClientConnectionSettings = + ClientConnectionSettings(actorSystem).withUserAgentHeader( + Option(`User-Agent`(applicationName.value + "/" + applicationVersion))) + + private val connectionPoolSettings: ConnectionPoolSettings = ConnectionPoolSettings(actorSystem) + .withConnectionSettings(clientConnectionSettings) + + private val pool = if (scheme.equalsIgnoreCase("https")) { + Http().cachedHostConnectionPoolHttps[Promise[HttpResponse]](host, port, settings = connectionPoolSettings) + } else { + Http().cachedHostConnectionPool[Promise[HttpResponse]](host, port, settings = connectionPoolSettings) + } + + private val queue = Source + .queue[(HttpRequest, Promise[HttpResponse])](requestQueueSize, OverflowStrategy.dropNew) + .via(pool) + .throttle(requestRateLimit, 1.second, maximumBurst = requestRateLimit, ThrottleMode.shaping) + .toMat(Sink.foreach({ + case ((Success(resp), p)) => p.success(resp) + case ((Failure(e), p)) => p.failure(e) + }))(Keep.left) + .run + + def makeRequest(request: HttpRequest): Future[HttpResponse] = { + val responsePromise = Promise[HttpResponse]() + + queue.offer(request -> responsePromise).flatMap { + case QueueOfferResult.Enqueued => + responsePromise.future + case QueueOfferResult.Dropped => + Future.failed(new Exception(s"Request queue to the host $host is overflown")) + case QueueOfferResult.Failure(ex) => + Future.failed(ex) + case QueueOfferResult.QueueClosed => + Future.failed(new Exception("Queue was closed (pool shut down) while running the request")) + } + } +} diff --git a/src/main/scala/xyz/driver/core/rest/RestService.scala b/src/main/scala/xyz/driver/core/rest/RestService.scala new file mode 100644 index 0000000..aed8d28 --- /dev/null +++ b/src/main/scala/xyz/driver/core/rest/RestService.scala @@ -0,0 +1,54 @@ +package xyz.driver.core.rest + +import akka.http.scaladsl.model._ +import akka.http.scaladsl.unmarshalling.{Unmarshal, Unmarshaller} +import akka.stream.Materializer + +import scala.concurrent.{ExecutionContext, Future} +import scalaz.{ListT, OptionT} + +trait RestService extends Service { + + import akka.http.scaladsl.marshallers.sprayjson.SprayJsonSupport._ + import spray.json._ + + protected implicit val exec: ExecutionContext + protected implicit val materializer: Materializer + + implicit class ResponseEntityFoldable(entity: Unmarshal[ResponseEntity]) { + def fold[T](default: => T)(implicit um: Unmarshaller[ResponseEntity, T]): Future[T] = + if (entity.value.isKnownEmpty()) Future.successful[T](default) else entity.to[T] + } + + protected def unitResponse(request: Future[Unmarshal[ResponseEntity]]): OptionT[Future, Unit] = + OptionT[Future, Unit](request.flatMap(_.to[String]).map(_ => Option(()))) + + protected def optionalResponse[T](request: Future[Unmarshal[ResponseEntity]])( + implicit um: Unmarshaller[ResponseEntity, Option[T]]): OptionT[Future, T] = + OptionT[Future, T](request.flatMap(_.fold(Option.empty[T]))) + + protected def listResponse[T](request: Future[Unmarshal[ResponseEntity]])( + implicit um: Unmarshaller[ResponseEntity, List[T]]): ListT[Future, T] = + ListT[Future, T](request.flatMap(_.fold(List.empty[T]))) + + protected def jsonEntity(json: JsValue): RequestEntity = + HttpEntity(ContentTypes.`application/json`, json.compactPrint) + + protected def get(baseUri: Uri, path: String, query: Seq[(String, String)] = Seq.empty) = + HttpRequest(HttpMethods.GET, endpointUri(baseUri, path, query)) + + protected def post(baseUri: Uri, path: String, httpEntity: RequestEntity) = + HttpRequest(HttpMethods.POST, endpointUri(baseUri, path), entity = httpEntity) + + protected def postJson(baseUri: Uri, path: String, json: JsValue) = + HttpRequest(HttpMethods.POST, endpointUri(baseUri, path), entity = jsonEntity(json)) + + protected def delete(baseUri: Uri, path: String) = + HttpRequest(HttpMethods.DELETE, endpointUri(baseUri, path)) + + protected def endpointUri(baseUri: Uri, path: String): Uri = + baseUri.withPath(Uri.Path(path)) + + protected def endpointUri(baseUri: Uri, path: String, query: Seq[(String, String)]): Uri = + baseUri.withPath(Uri.Path(path)).withQuery(Uri.Query(query: _*)) +} diff --git a/src/main/scala/xyz/driver/core/rest/SavingUsedServiceDiscovery.scala b/src/main/scala/xyz/driver/core/rest/SavingUsedServiceDiscovery.scala new file mode 100644 index 0000000..8018bdf --- /dev/null +++ b/src/main/scala/xyz/driver/core/rest/SavingUsedServiceDiscovery.scala @@ -0,0 +1,14 @@ +package xyz.driver.core.rest + +import xyz.driver.core.Name + +trait SavingUsedServiceDiscovery { + + private val usedServices = new scala.collection.mutable.HashSet[String]() + + def saveServiceUsage(serviceName: Name[Service]): Unit = usedServices.synchronized { + usedServices += serviceName.value + } + + def getUsedServices: Set[String] = usedServices.synchronized { usedServices.toSet } +} diff --git a/src/main/scala/xyz/driver/core/rest/Service.scala b/src/main/scala/xyz/driver/core/rest/Service.scala new file mode 100644 index 0000000..8216ab7 --- /dev/null +++ b/src/main/scala/xyz/driver/core/rest/Service.scala @@ -0,0 +1,3 @@ +package xyz.driver.core.rest + +trait Service diff --git a/src/main/scala/xyz/driver/core/rest/ServiceDiscovery.scala b/src/main/scala/xyz/driver/core/rest/ServiceDiscovery.scala new file mode 100644 index 0000000..f0f3b5b --- /dev/null +++ b/src/main/scala/xyz/driver/core/rest/ServiceDiscovery.scala @@ -0,0 +1,8 @@ +package xyz.driver.core.rest + +import xyz.driver.core.Name + +trait ServiceDiscovery { + + def discover[T <: Service](serviceName: Name[Service]): T +} diff --git a/src/main/scala/xyz/driver/core/rest/ServiceRequestContext.scala b/src/main/scala/xyz/driver/core/rest/ServiceRequestContext.scala new file mode 100644 index 0000000..5235da6 --- /dev/null +++ b/src/main/scala/xyz/driver/core/rest/ServiceRequestContext.scala @@ -0,0 +1,39 @@ +package xyz.driver.core.rest + +import xyz.driver.core.auth.{AuthToken, PermissionsToken, User} +import xyz.driver.core.generators + +import scalaz.Scalaz.{mapEqual, stringInstance} +import scalaz.syntax.equal._ + +class ServiceRequestContext(val trackingId: String = generators.nextUuid().toString, + val contextHeaders: Map[String, String] = Map.empty[String, String]) { + def authToken: Option[AuthToken] = + contextHeaders.get(AuthProvider.AuthenticationTokenHeader).map(AuthToken.apply) + + def permissionsToken: Option[PermissionsToken] = + contextHeaders.get(AuthProvider.PermissionsTokenHeader).map(PermissionsToken.apply) + + def withAuthToken(authToken: AuthToken): ServiceRequestContext = + new ServiceRequestContext( + trackingId, + contextHeaders.updated(AuthProvider.AuthenticationTokenHeader, authToken.value) + ) + + def withAuthenticatedUser[U <: User](authToken: AuthToken, user: U): AuthorizedServiceRequestContext[U] = + new AuthorizedServiceRequestContext( + trackingId, + contextHeaders.updated(AuthProvider.AuthenticationTokenHeader, authToken.value), + user + ) + + override def hashCode(): Int = + Seq[Any](trackingId, contextHeaders).foldLeft(31)((result, obj) => 31 * result + obj.hashCode()) + + override def equals(obj: Any): Boolean = obj match { + case ctx: ServiceRequestContext => trackingId === ctx.trackingId && contextHeaders === ctx.contextHeaders + case _ => false + } + + override def toString: String = s"ServiceRequestContext($trackingId, $contextHeaders)" +} diff --git a/src/main/scala/xyz/driver/core/rest/ServiceTransport.scala b/src/main/scala/xyz/driver/core/rest/ServiceTransport.scala new file mode 100644 index 0000000..9c0c429 --- /dev/null +++ b/src/main/scala/xyz/driver/core/rest/ServiceTransport.scala @@ -0,0 +1,13 @@ +package xyz.driver.core.rest + +import akka.http.scaladsl.model.{HttpRequest, HttpResponse, ResponseEntity} +import akka.http.scaladsl.unmarshalling.Unmarshal + +import scala.concurrent.Future + +trait ServiceTransport { + + def sendRequestGetResponse(context: ServiceRequestContext)(requestStub: HttpRequest): Future[HttpResponse] + + def sendRequest(context: ServiceRequestContext)(requestStub: HttpRequest): Future[Unmarshal[ResponseEntity]] +} diff --git a/src/main/scala/xyz/driver/core/rest/SingleRequestHttpClient.scala b/src/main/scala/xyz/driver/core/rest/SingleRequestHttpClient.scala new file mode 100644 index 0000000..4f1f7d0 --- /dev/null +++ b/src/main/scala/xyz/driver/core/rest/SingleRequestHttpClient.scala @@ -0,0 +1,29 @@ +package xyz.driver.core.rest + +import akka.actor.ActorSystem +import akka.http.scaladsl.Http +import akka.http.scaladsl.model.headers.`User-Agent` +import akka.http.scaladsl.model.{HttpRequest, HttpResponse} +import akka.http.scaladsl.settings.{ClientConnectionSettings, ConnectionPoolSettings} +import akka.stream.ActorMaterializer +import xyz.driver.core.Name + +import scala.concurrent.Future + +class SingleRequestHttpClient(applicationName: Name[App], applicationVersion: String, actorSystem: ActorSystem) + extends HttpClient { + + protected implicit val materializer: ActorMaterializer = ActorMaterializer()(actorSystem) + private val client = Http()(actorSystem) + + private val clientConnectionSettings: ClientConnectionSettings = + ClientConnectionSettings(actorSystem).withUserAgentHeader( + Option(`User-Agent`(applicationName.value + "/" + applicationVersion))) + + private val connectionPoolSettings: ConnectionPoolSettings = ConnectionPoolSettings(actorSystem) + .withConnectionSettings(clientConnectionSettings) + + def makeRequest(request: HttpRequest): Future[HttpResponse] = { + client.singleRequest(request, settings = connectionPoolSettings)(materializer) + } +} diff --git a/src/main/scala/xyz/driver/core/rest/Swagger.scala b/src/main/scala/xyz/driver/core/rest/Swagger.scala new file mode 100644 index 0000000..e0efeaf --- /dev/null +++ b/src/main/scala/xyz/driver/core/rest/Swagger.scala @@ -0,0 +1,52 @@ +package xyz.driver.core.rest + +import akka.actor.ActorSystem +import akka.http.scaladsl.server.Route +import akka.stream._ +import com.github.swagger.akka.model._ +import com.github.swagger.akka.{HasActorSystem, SwaggerHttpService} +import com.typesafe.config.Config +import io.swagger.models.Scheme + +import scala.reflect.runtime.universe._ + +class Swagger(override val host: String, + override val scheme: Scheme, + version: String, + override val actorSystem: ActorSystem, + override val apiTypes: Seq[Type], + val config: Config) + extends SwaggerHttpService with HasActorSystem { + + val materializer: ActorMaterializer = ActorMaterializer()(actorSystem) + + override val basePath: String = config.getString("swagger.basePath") + override val apiDocsPath: String = config.getString("swagger.docsPath") + + override val info = Info( + config.getString("swagger.apiInfo.description"), + version, + config.getString("swagger.apiInfo.title"), + config.getString("swagger.apiInfo.termsOfServiceUrl"), + contact = Some( + Contact( + config.getString("swagger.apiInfo.contact.name"), + config.getString("swagger.apiInfo.contact.url"), + config.getString("swagger.apiInfo.contact.email") + )), + license = Some( + License( + config.getString("swagger.apiInfo.license"), + config.getString("swagger.apiInfo.licenseUrl") + )), + vendorExtensions = Map.empty[String, AnyRef] + ) + + def swaggerUI: Route = get { + pathPrefix("") { + pathEndOrSingleSlash { + getFromResource("swagger-ui/index.html") + } + } ~ getFromResourceDirectory("swagger-ui") + } +} diff --git a/src/main/scala/xyz/driver/core/rest/auth/AlwaysAllowAuthorization.scala b/src/main/scala/xyz/driver/core/rest/auth/AlwaysAllowAuthorization.scala new file mode 100644 index 0000000..ea29a6a --- /dev/null +++ b/src/main/scala/xyz/driver/core/rest/auth/AlwaysAllowAuthorization.scala @@ -0,0 +1,14 @@ +package xyz.driver.core.rest.auth + +import xyz.driver.core.auth.{Permission, User} +import xyz.driver.core.rest.ServiceRequestContext + +import scala.concurrent.Future + +class AlwaysAllowAuthorization[U <: User] extends Authorization[U] { + override def userHasPermissions(user: U, permissions: Seq[Permission])( + implicit ctx: ServiceRequestContext): Future[AuthorizationResult] = { + val permissionsMap = permissions.map(_ -> true).toMap + Future.successful(AuthorizationResult(authorized = permissionsMap, ctx.permissionsToken)) + } +} diff --git a/src/main/scala/xyz/driver/core/rest/auth/AuthProvider.scala b/src/main/scala/xyz/driver/core/rest/auth/AuthProvider.scala new file mode 100644 index 0000000..35b65f7 --- /dev/null +++ b/src/main/scala/xyz/driver/core/rest/auth/AuthProvider.scala @@ -0,0 +1,64 @@ +package xyz.driver.core.rest.auth + +import akka.http.scaladsl.model.headers.HttpChallenges +import akka.http.scaladsl.server.AuthenticationFailedRejection.CredentialsRejected +import com.typesafe.scalalogging.Logger +import xyz.driver.core._ +import xyz.driver.core.auth.{Permission, User} +import xyz.driver.core.rest.{AuthorizedServiceRequestContext, ServiceRequestContext, serviceContext} + +import scala.concurrent.{ExecutionContext, Future} +import scala.util.{Failure, Success} + +import scalaz.Scalaz.futureInstance +import scalaz.OptionT + +abstract class AuthProvider[U <: User](val authorization: Authorization[U], log: Logger)( + implicit execution: ExecutionContext) { + + import akka.http.scaladsl.server._ + import Directives._ + + /** + * Specific implementation on how to extract user from request context, + * can either need to do a network call to auth server or extract everything from self-contained token + * + * @param ctx set of request values which can be relevant to authenticate user + * @return authenticated user + */ + def authenticatedUser(implicit ctx: ServiceRequestContext): OptionT[Future, U] + + /** + * Verifies if request is authenticated and authorized to have `permissions` + */ + def authorize(permissions: Permission*): Directive1[AuthorizedServiceRequestContext[U]] = { + serviceContext flatMap { ctx => + onComplete { + (for { + authToken <- OptionT.optionT(Future.successful(ctx.authToken)) + user <- authenticatedUser(ctx) + authCtx = ctx.withAuthenticatedUser(authToken, user) + authorizationResult <- authorization.userHasPermissions(user, permissions)(authCtx).toOptionT + + cachedPermissionsAuthCtx = authorizationResult.token.fold(authCtx)(authCtx.withPermissionsToken) + allAuthorized = permissions.forall(authorizationResult.authorized.getOrElse(_, false)) + } yield (cachedPermissionsAuthCtx, allAuthorized)).run + } flatMap { + case Success(Some((authCtx, true))) => provide(authCtx) + case Success(Some((authCtx, false))) => + val challenge = + HttpChallenges.basic(s"User does not have the required permissions: ${permissions.mkString(", ")}") + log.warn( + s"User ${authCtx.authenticatedUser} does not have the required permissions: ${permissions.mkString(", ")}") + reject(AuthenticationFailedRejection(CredentialsRejected, challenge)) + case Success(None) => + log.warn( + s"Wasn't able to find authenticated user for the token provided to verify ${permissions.mkString(", ")}") + reject(ValidationRejection(s"Wasn't able to find authenticated user for the token provided")) + case Failure(t) => + log.warn(s"Wasn't able to verify token for authenticated user to verify ${permissions.mkString(", ")}", t) + reject(ValidationRejection(s"Wasn't able to verify token for authenticated user", Some(t))) + } + } + } +} diff --git a/src/main/scala/xyz/driver/core/rest/auth/Authorization.scala b/src/main/scala/xyz/driver/core/rest/auth/Authorization.scala new file mode 100644 index 0000000..87d0614 --- /dev/null +++ b/src/main/scala/xyz/driver/core/rest/auth/Authorization.scala @@ -0,0 +1,11 @@ +package xyz.driver.core.rest.auth + +import xyz.driver.core.auth.{Permission, User} +import xyz.driver.core.rest.ServiceRequestContext + +import scala.concurrent.Future + +trait Authorization[U <: User] { + def userHasPermissions(user: U, permissions: Seq[Permission])( + implicit ctx: ServiceRequestContext): Future[AuthorizationResult] +} diff --git a/src/main/scala/xyz/driver/core/rest/auth/AuthorizationResult.scala b/src/main/scala/xyz/driver/core/rest/auth/AuthorizationResult.scala new file mode 100644 index 0000000..efe28c9 --- /dev/null +++ b/src/main/scala/xyz/driver/core/rest/auth/AuthorizationResult.scala @@ -0,0 +1,22 @@ +package xyz.driver.core.rest.auth + +import xyz.driver.core.auth.{Permission, PermissionsToken} + +import scalaz.Scalaz.mapMonoid +import scalaz.Semigroup +import scalaz.syntax.semigroup._ + +final case class AuthorizationResult(authorized: Map[Permission, Boolean], token: Option[PermissionsToken]) +object AuthorizationResult { + val unauthorized: AuthorizationResult = AuthorizationResult(authorized = Map.empty, None) + + implicit val authorizationSemigroup: Semigroup[AuthorizationResult] = new Semigroup[AuthorizationResult] { + private implicit val authorizedBooleanSemigroup = Semigroup.instance[Boolean](_ || _) + private implicit val permissionsTokenSemigroup = + Semigroup.instance[Option[PermissionsToken]]((a, b) => b.orElse(a)) + + override def append(a: AuthorizationResult, b: => AuthorizationResult): AuthorizationResult = { + AuthorizationResult(a.authorized |+| b.authorized, a.token |+| b.token) + } + } +} diff --git a/src/main/scala/xyz/driver/core/rest/auth/CachedTokenAuthorization.scala b/src/main/scala/xyz/driver/core/rest/auth/CachedTokenAuthorization.scala new file mode 100644 index 0000000..4f4c811 --- /dev/null +++ b/src/main/scala/xyz/driver/core/rest/auth/CachedTokenAuthorization.scala @@ -0,0 +1,55 @@ +package xyz.driver.core.rest.auth + +import java.nio.file.{Files, Path} +import java.security.{KeyFactory, PublicKey} +import java.security.spec.X509EncodedKeySpec + +import pdi.jwt.{Jwt, JwtAlgorithm} +import xyz.driver.core.auth.{Permission, User} +import xyz.driver.core.rest.ServiceRequestContext + +import scala.concurrent.Future +import scalaz.syntax.std.boolean._ + +class CachedTokenAuthorization[U <: User](publicKey: => PublicKey, issuer: String) extends Authorization[U] { + override def userHasPermissions(user: U, permissions: Seq[Permission])( + implicit ctx: ServiceRequestContext): Future[AuthorizationResult] = { + import spray.json._ + + def extractPermissionsFromTokenJSON(tokenObject: JsObject): Option[Map[String, Boolean]] = + tokenObject.fields.get("permissions").collect { + case JsObject(fields) => + fields.collect { + case (key, JsBoolean(value)) => key -> value + } + } + + val result = for { + token <- ctx.permissionsToken + jwt <- Jwt.decode(token.value, publicKey, Seq(JwtAlgorithm.RS256)).toOption + jwtJson = jwt.parseJson.asJsObject + + // Ensure jwt is for the currently authenticated user and the correct issuer, otherwise return None + _ <- jwtJson.fields.get("sub").contains(JsString(user.id.value)).option(()) + _ <- jwtJson.fields.get("iss").contains(JsString(issuer)).option(()) + + permissionsMap <- extractPermissionsFromTokenJSON(jwtJson) + + authorized = permissions.map(p => p -> permissionsMap.getOrElse(p.toString, false)).toMap + } yield AuthorizationResult(authorized, Some(token)) + + Future.successful(result.getOrElse(AuthorizationResult.unauthorized)) + } +} + +object CachedTokenAuthorization { + def apply[U <: User](publicKeyFile: Path, issuer: String): CachedTokenAuthorization[U] = { + lazy val publicKey: PublicKey = { + val publicKeyBase64Encoded = new String(Files.readAllBytes(publicKeyFile)).trim + val publicKeyBase64Decoded = java.util.Base64.getDecoder.decode(publicKeyBase64Encoded) + val spec = new X509EncodedKeySpec(publicKeyBase64Decoded) + KeyFactory.getInstance("RSA").generatePublic(spec) + } + new CachedTokenAuthorization[U](publicKey, issuer) + } +} diff --git a/src/main/scala/xyz/driver/core/rest/auth/ChainedAuthorization.scala b/src/main/scala/xyz/driver/core/rest/auth/ChainedAuthorization.scala new file mode 100644 index 0000000..f5eb402 --- /dev/null +++ b/src/main/scala/xyz/driver/core/rest/auth/ChainedAuthorization.scala @@ -0,0 +1,27 @@ +package xyz.driver.core.rest.auth + +import xyz.driver.core.auth.{Permission, User} +import xyz.driver.core.rest.ServiceRequestContext + +import scala.concurrent.{ExecutionContext, Future} +import scalaz.Scalaz.{futureInstance, listInstance} +import scalaz.syntax.semigroup._ +import scalaz.syntax.traverse._ + +class ChainedAuthorization[U <: User](authorizations: Authorization[U]*)(implicit execution: ExecutionContext) + extends Authorization[U] { + + override def userHasPermissions(user: U, permissions: Seq[Permission])( + implicit ctx: ServiceRequestContext): Future[AuthorizationResult] = { + def allAuthorized(permissionsMap: Map[Permission, Boolean]): Boolean = + permissions.forall(permissionsMap.getOrElse(_, false)) + + authorizations.toList.foldLeftM[Future, AuthorizationResult](AuthorizationResult.unauthorized) { + (authResult, authorization) => + if (allAuthorized(authResult.authorized)) Future.successful(authResult) + else { + authorization.userHasPermissions(user, permissions).map(authResult |+| _) + } + } + } +} diff --git a/src/main/scala/xyz/driver/core/rest/package.scala b/src/main/scala/xyz/driver/core/rest/package.scala new file mode 100644 index 0000000..4c8e13c --- /dev/null +++ b/src/main/scala/xyz/driver/core/rest/package.scala @@ -0,0 +1,92 @@ +package xyz.driver.core + +import akka.http.scaladsl.model.HttpRequest +import akka.http.scaladsl.server.Directives._ +import akka.http.scaladsl.server._ +import akka.stream.scaladsl.Flow +import akka.util.ByteString + +import scalaz.Scalaz.{intInstance, stringInstance} +import scalaz.syntax.equal._ + +package object rest { + object ContextHeaders { + val AuthenticationTokenHeader = "Authorization" + val PermissionsTokenHeader = "Permissions" + val AuthenticationHeaderPrefix = "Bearer" + val TrackingIdHeader = "X-Trace" + val StacktraceHeader = "X-Stacktrace" + val TracingHeader = trace.TracingHeaderKey + } + + object AuthProvider { + val AuthenticationTokenHeader = ContextHeaders.AuthenticationTokenHeader + val PermissionsTokenHeader = ContextHeaders.PermissionsTokenHeader + val SetAuthenticationTokenHeader = "set-authorization" + val SetPermissionsTokenHeader = "set-permissions" + } + + def serviceContext: Directive1[ServiceRequestContext] = extract(ctx => extractServiceContext(ctx.request)) + + def extractServiceContext(request: HttpRequest): ServiceRequestContext = + new ServiceRequestContext(extractTrackingId(request), extractContextHeaders(request)) + + def extractTrackingId(request: HttpRequest): String = { + request.headers + .find(_.name == ContextHeaders.TrackingIdHeader) + .fold(java.util.UUID.randomUUID.toString)(_.value()) + } + + def extractStacktrace(request: HttpRequest): Array[String] = + request.headers.find(_.name == ContextHeaders.StacktraceHeader).fold("")(_.value()).split("->") + + def extractContextHeaders(request: HttpRequest): Map[String, String] = { + request.headers.filter { h => + h.name === ContextHeaders.AuthenticationTokenHeader || h.name === ContextHeaders.TrackingIdHeader || + h.name === ContextHeaders.PermissionsTokenHeader || h.name === ContextHeaders.StacktraceHeader || + h.name === ContextHeaders.TracingHeader + } map { header => + if (header.name === ContextHeaders.AuthenticationTokenHeader) { + header.name -> header.value.stripPrefix(ContextHeaders.AuthenticationHeaderPrefix).trim + } else { + header.name -> header.value + } + } toMap + } + + private[rest] def escapeScriptTags(byteString: ByteString): ByteString = { + @annotation.tailrec + def dirtyIndices(from: Int, descIndices: List[Int]): List[Int] = { + val index = byteString.indexOf('/', from) + if (index === -1) descIndices.reverse + else { + val (init, tail) = byteString.splitAt(index) + if ((init endsWith "<") && (tail startsWith "/sc")) { + dirtyIndices(index + 1, index :: descIndices) + } else { + dirtyIndices(index + 1, descIndices) + } + } + } + + val indices = dirtyIndices(0, Nil) + + indices.headOption.fold(byteString) { head => + val builder = ByteString.newBuilder + builder ++= byteString.take(head) + + (indices :+ byteString.length).sliding(2).foreach { + case Seq(start, end) => + builder += ' ' + builder ++= byteString.slice(start, end) + case Seq(_) => // Should not match; sliding on at least 2 elements + assert(indices.nonEmpty, s"Indices should have been nonEmpty: $indices") + } + builder.result + } + } + + val sanitizeRequestEntity: Directive0 = { + mapRequest(request => request.mapEntity(entity => entity.transformDataBytes(Flow.fromFunction(escapeScriptTags)))) + } +} -- cgit v1.2.3 From 6e9f40e4cacedfab43c92248d425866d73ea700e Mon Sep 17 00:00:00 2001 From: Zach Smith Date: Mon, 16 Oct 2017 09:46:20 -0700 Subject: Split up app package into separate files --- src/main/scala/xyz/driver/core/app.scala | 426 --------------------- src/main/scala/xyz/driver/core/app/DriverApp.scala | 361 +++++++++++++++++ src/main/scala/xyz/driver/core/app/Module.scala | 53 +++ src/main/scala/xyz/driver/core/rest/package.scala | 49 ++- 4 files changed, 452 insertions(+), 437 deletions(-) delete mode 100644 src/main/scala/xyz/driver/core/app.scala create mode 100644 src/main/scala/xyz/driver/core/app/DriverApp.scala create mode 100644 src/main/scala/xyz/driver/core/app/Module.scala (limited to 'src/main/scala/xyz/driver/core') diff --git a/src/main/scala/xyz/driver/core/app.scala b/src/main/scala/xyz/driver/core/app.scala deleted file mode 100644 index 19eef52..0000000 --- a/src/main/scala/xyz/driver/core/app.scala +++ /dev/null @@ -1,426 +0,0 @@ -package xyz.driver.core - -import java.sql.SQLException - -import akka.actor.ActorSystem -import akka.http.scaladsl.Http -import akka.http.scaladsl.marshallers.sprayjson.SprayJsonSupport -import akka.http.scaladsl.model.StatusCodes._ -import akka.http.scaladsl.model._ -import akka.http.scaladsl.model.headers._ -import akka.http.scaladsl.server.Directives._ -import akka.http.scaladsl.server.RouteResult._ -import akka.http.scaladsl.server._ -import akka.stream.ActorMaterializer -import com.github.swagger.akka.SwaggerHttpService._ -import com.typesafe.config.Config -import com.typesafe.scalalogging.Logger -import io.swagger.models.Scheme -import io.swagger.util.Json -import org.slf4j.{LoggerFactory, MDC} -import xyz.driver.core -import xyz.driver.core.rest._ -import xyz.driver.core.stats.SystemStats -import xyz.driver.core.time.Time -import xyz.driver.core.time.provider.{SystemTimeProvider, TimeProvider} -import xyz.driver.tracing.TracingDirectives._ -import xyz.driver.tracing._ - -import scala.compat.Platform.ConcurrentModificationException -import scala.concurrent.duration._ -import scala.concurrent.{Await, ExecutionContext, Future} -import scala.reflect.runtime.universe._ -import scala.util.Try -import scala.util.control.NonFatal -import scalaz.Scalaz.stringInstance -import scalaz.syntax.equal._ - -object app { - - class DriverApp(appName: String, - version: String, - gitHash: String, - modules: Seq[Module], - time: TimeProvider = new SystemTimeProvider(), - log: Logger = Logger(LoggerFactory.getLogger(classOf[DriverApp])), - config: Config = core.config.loadDefaultConfig, - interface: String = "::0", - baseUrl: String = "localhost:8080", - scheme: String = "http", - port: Int = 8080, - tracer: Tracer = NoTracer)(implicit actorSystem: ActorSystem, executionContext: ExecutionContext) { - - implicit private lazy val materializer = ActorMaterializer()(actorSystem) - private lazy val http = Http()(actorSystem) - val appEnvironment = config.getString("application.environment") - - def run(): Unit = { - activateServices(modules) - scheduleServicesDeactivation(modules) - bindHttp(modules) - Console.print(s"${this.getClass.getName} App is started\n") - } - - def stop(): Unit = { - http.shutdownAllConnectionPools().onComplete { _ => - Await.result(tracer.close(), 15.seconds) // flush out any remaining traces from the buffer - val _ = actorSystem.terminate() - val terminated = Await.result(actorSystem.whenTerminated, 30.seconds) - val addressTerminated = if (terminated.addressTerminated) "is" else "is not" - Console.print(s"${this.getClass.getName} App $addressTerminated stopped ") - } - } - - private def extractHeader(request: HttpRequest)(headerName: String): Option[String] = - request.headers.find(_.name().toLowerCase === headerName).map(_.value()) - - private val allowedHeaders = - Seq( - "Origin", - "X-Requested-With", - "Content-Type", - "Content-Length", - "Accept", - "X-Trace", - "Access-Control-Allow-Methods", - "Access-Control-Allow-Origin", - "Access-Control-Allow-Headers", - "Server", - "Date", - ContextHeaders.TrackingIdHeader, - ContextHeaders.TraceHeaderName, - ContextHeaders.SpanHeaderName, - ContextHeaders.StacktraceHeader, - ContextHeaders.AuthenticationTokenHeader, - "X-Frame-Options", - "X-Content-Type-Options", - "Strict-Transport-Security", - AuthProvider.SetAuthenticationTokenHeader, - AuthProvider.SetPermissionsTokenHeader - ) - - private def allowOrigin(originHeader: Option[Origin]) = - `Access-Control-Allow-Origin`( - originHeader.fold[HttpOriginRange](HttpOriginRange.*)(h => HttpOriginRange(h.origins: _*))) - - protected implicit def rejectionHandler = - RejectionHandler - .newBuilder() - .handleAll[MethodRejection] { rejections => - val methods = rejections map (_.supported) - lazy val names = methods map (_.name) mkString ", " - - options { ctx => - optionalHeaderValueByType[Origin](()) { originHeader => - respondWithHeaders(List[HttpHeader]( - Allow(methods), - `Access-Control-Allow-Methods`(methods), - allowOrigin(originHeader), - `Access-Control-Allow-Headers`(allowedHeaders: _*), - `Access-Control-Expose-Headers`(allowedHeaders: _*) - )) { - complete(s"Supported methods: $names.") - } - }(ctx) - } ~ - complete(MethodNotAllowed -> s"HTTP method not allowed, supported methods: $names!") - } - .result() - - protected def bindHttp(modules: Seq[Module]): Unit = { - val serviceTypes = modules.flatMap(_.routeTypes) - val swaggerService = swaggerOverride(serviceTypes) - val swaggerRoutes = swaggerService.routes ~ swaggerService.swaggerUI - val versionRt = versionRoute(version, gitHash, time.currentTime()) - - val _ = Future { - http.bindAndHandle( - route2HandlerFlow(extractHost { origin => - trace(tracer) { - extractClientIP { - ip => - optionalHeaderValueByType[Origin](()) { - originHeader => - { - ctx => - val trackingId = rest.extractTrackingId(ctx.request) - MDC.put("trackingId", trackingId) - - val updatedStacktrace = - (rest.extractStacktrace(ctx.request) ++ Array(appName)).mkString("->") - MDC.put("stack", updatedStacktrace) - - storeRequestContextToMdc(ctx.request, origin, ip) - - def requestLogging: Future[Unit] = Future { - log.info( - s"""Received request {"method":"${ctx.request.method.value}","url": "${ctx.request.uri}"}""") - } - - val contextWithTrackingId = - ctx.withRequest( - ctx.request - .addHeader(RawHeader(ContextHeaders.TrackingIdHeader, trackingId)) - .addHeader(RawHeader(ContextHeaders.StacktraceHeader, updatedStacktrace))) - - handleExceptions(ExceptionHandler(exceptionHandler))({ - c => - requestLogging.flatMap { _ => - val trackingHeader = RawHeader(ContextHeaders.TrackingIdHeader, trackingId) - - val responseHeaders = List[HttpHeader]( - trackingHeader, - allowOrigin(originHeader), - `Access-Control-Allow-Headers`(allowedHeaders: _*), - `Access-Control-Expose-Headers`(allowedHeaders: _*) - ) - - respondWithHeaders(responseHeaders) { - modules.map(_.route).foldLeft(versionRt ~ healthRoute ~ swaggerRoutes)(_ ~ _) - }(c) - } - })(contextWithTrackingId) - } - } - } - } - }), - interface, - port - )(materializer) - } - } - - private def storeRequestContextToMdc(request: HttpRequest, origin: String, ip: RemoteAddress) = { - - MDC.put("origin", origin) - MDC.put("ip", ip.toOption.map(_.getHostAddress).getOrElse("unknown")) - MDC.put("remoteHost", ip.toOption.map(_.getHostName).getOrElse("unknown")) - - MDC.put("xForwardedFor", - extractHeader(request)("x-forwarded-for") - .orElse(extractHeader(request)("x_forwarded_for")) - .getOrElse("unknown")) - MDC.put("remoteAddress", extractHeader(request)("remote-address").getOrElse("unknown")) - MDC.put("userAgent", extractHeader(request)("user-agent").getOrElse("unknown")) - } - - protected def swaggerOverride(apiTypes: Seq[Type]) = { - new Swagger(baseUrl, Scheme.forValue(scheme), version, actorSystem, apiTypes, config) { - override def generateSwaggerJson: String = { - import io.swagger.models.Swagger - - import scala.collection.JavaConverters._ - - try { - val swagger: Swagger = reader.read(toJavaTypeSet(apiTypes).asJava) - - // Removing trailing spaces - swagger.setPaths( - swagger.getPaths.asScala - .map { - case (key, path) => - key.trim -> path - } - .toMap - .asJava) - - Json.pretty().writeValueAsString(swagger) - } catch { - case NonFatal(t) => { - logger.error("Issue with creating swagger.json", t) - throw t - } - } - } - } - } - - /** - * Override me for custom exception handling - * - * @return Exception handling route for exception type - */ - protected def exceptionHandler = PartialFunction[Throwable, Route] { - - case is: IllegalStateException => - ctx => - log.warn(s"Request is not allowed to ${ctx.request.method} ${ctx.request.uri}", is) - errorResponse(ctx, BadRequest, message = is.getMessage, is)(ctx) - - case cm: ConcurrentModificationException => - ctx => - log.warn(s"Concurrent modification of the resource ${ctx.request.method} ${ctx.request.uri}", cm) - errorResponse(ctx, Conflict, "Resource was changed concurrently, try requesting a newer version", cm)(ctx) - - case se: SQLException => - ctx => - log.warn(s"Database exception for the resource ${ctx.request.method} ${ctx.request.uri}", se) - errorResponse(ctx, InternalServerError, "Data access error", se)(ctx) - - case t: Throwable => - ctx => - log.warn(s"Request to ${ctx.request.method} ${ctx.request.uri} could not be handled normally", t) - errorResponse(ctx, InternalServerError, t.getMessage, t)(ctx) - } - - protected def errorResponse[T <: Throwable](ctx: RequestContext, - statusCode: StatusCode, - message: String, - exception: T): Route = { - - val trackingId = rest.extractTrackingId(ctx.request) - val tracingHeader = RawHeader(ContextHeaders.TrackingIdHeader, rest.extractTrackingId(ctx.request)) - - MDC.put("trackingId", trackingId) - - optionalHeaderValueByType[Origin](()) { originHeader => - val responseHeaders = List[HttpHeader](tracingHeader, - allowOrigin(originHeader), - `Access-Control-Allow-Headers`(allowedHeaders: _*), - `Access-Control-Expose-Headers`(allowedHeaders: _*)) - - respondWithHeaders(responseHeaders) { - complete(HttpResponse(statusCode, entity = message)) - } - } - } - - protected def versionRoute(version: String, gitHash: String, startupTime: Time): Route = { - import spray.json._ - import DefaultJsonProtocol._ - import SprayJsonSupport._ - - path("version") { - val currentTime = time.currentTime().millis - complete( - Map( - "version" -> version.toJson, - "gitHash" -> gitHash.toJson, - "modules" -> modules.map(_.name).toJson, - "dependencies" -> collectAppDependencies().toJson, - "startupTime" -> startupTime.millis.toString.toJson, - "serverTime" -> currentTime.toString.toJson, - "uptime" -> (currentTime - startupTime.millis).toString.toJson - ).toJson) - } - } - - protected def collectAppDependencies(): Map[String, String] = { - - def serviceWithLocation(serviceName: String): (String, String) = - serviceName -> Try(config.getString(s"services.$serviceName.baseUrl")).getOrElse("not-detected") - - modules.flatMap(module => module.serviceDiscovery.getUsedServices.map(serviceWithLocation).toSeq).toMap - } - - protected def healthRoute: Route = { - import spray.json._ - import DefaultJsonProtocol._ - import SprayJsonSupport._ - import spray.json._ - - val memoryUsage = SystemStats.memoryUsage - val gcStats = SystemStats.garbageCollectorStats - - path("health") { - complete( - Map( - "availableProcessors" -> SystemStats.availableProcessors.toJson, - "memoryUsage" -> Map( - "free" -> memoryUsage.free.toJson, - "total" -> memoryUsage.total.toJson, - "max" -> memoryUsage.max.toJson - ).toJson, - "gcStats" -> Map( - "garbageCollectionTime" -> gcStats.garbageCollectionTime.toJson, - "totalGarbageCollections" -> gcStats.totalGarbageCollections.toJson - ).toJson, - "fileSystemSpace" -> SystemStats.fileSystemSpace.map { f => - Map("path" -> f.path.toJson, - "freeSpace" -> f.freeSpace.toJson, - "totalSpace" -> f.totalSpace.toJson, - "usableSpace" -> f.usableSpace.toJson) - }.toJson, - "operatingSystem" -> SystemStats.operatingSystemStats.toJson - )) - } - } - - /** - * Initializes services - */ - protected def activateServices(services: Seq[Module]): Unit = { - services.foreach { service => - Console.print(s"Service ${service.name} starts ...") - try { - service.activate() - } catch { - case t: Throwable => - log.error(s"Service ${service.name} failed to activate", t) - Console.print(" Failed! (check log)") - } - Console.print(" Done\n") - } - } - - /** - * Schedules services to be deactivated on the app shutdown - */ - protected def scheduleServicesDeactivation(services: Seq[Module]) = { - Runtime.getRuntime.addShutdownHook(new Thread() { - override def run(): Unit = { - services.foreach { service => - Console.print(s"Service ${service.name} shutting down ...\n") - try { - service.deactivate() - } catch { - case t: Throwable => - log.error(s"Service ${service.name} failed to deactivate", t) - Console.print(" Failed! (check log)") - } - Console.print(s"Service ${service.name} is shut down\n") - } - } - }) - } - } - - trait Module { - val name: String - def route: Route - def routeTypes: Seq[Type] - - val serviceDiscovery: ServiceDiscovery with SavingUsedServiceDiscovery = new NoServiceDiscovery() - - def activate(): Unit = {} - def deactivate(): Unit = {} - } - - class EmptyModule extends Module { - val name = "Nothing" - def route: Route = complete(StatusCodes.OK) - def routeTypes = Seq.empty[Type] - } - - class SimpleModule(val name: String, val route: Route, routeType: Type) extends Module { - def routeTypes: Seq[Type] = Seq(routeType) - } - - /** - * Module implementation which may be used to composed a few - * - * @param name more general name of the composite module, - * must be provided as there is no good way to automatically - * generalize the name from the composed modules' names - * @param modules modules to compose into a single one - */ - class CompositeModule(val name: String, modules: Seq[Module]) extends Module with RouteConcatenation { - - def route: Route = RouteConcatenation.concat(modules.map(_.route): _*) - def routeTypes = modules.flatMap(_.routeTypes) - - override def activate() = modules.foreach(_.activate()) - override def deactivate() = modules.reverse.foreach(_.deactivate()) - } -} diff --git a/src/main/scala/xyz/driver/core/app/DriverApp.scala b/src/main/scala/xyz/driver/core/app/DriverApp.scala new file mode 100644 index 0000000..b73f426 --- /dev/null +++ b/src/main/scala/xyz/driver/core/app/DriverApp.scala @@ -0,0 +1,361 @@ +package xyz.driver.core.app + +import java.sql.SQLException + +import akka.actor.ActorSystem +import akka.http.scaladsl.marshallers.sprayjson.SprayJsonSupport +import akka.http.scaladsl.model.StatusCodes.{BadRequest, Conflict, InternalServerError, MethodNotAllowed} +import akka.http.scaladsl.model._ +import akka.http.scaladsl.model.headers._ +import akka.http.scaladsl.server.Directives._ +import akka.http.scaladsl.server.RouteResult.route2HandlerFlow +import akka.http.scaladsl.server._ +import akka.http.scaladsl.{Http, HttpExt} +import akka.stream.ActorMaterializer +import com.github.swagger.akka.SwaggerHttpService.{logger, toJavaTypeSet} +import com.typesafe.config.Config +import com.typesafe.scalalogging.Logger +import io.swagger.models.Scheme +import io.swagger.util.Json +import org.slf4j.{LoggerFactory, MDC} +import xyz.driver.core +import xyz.driver.core.rest +import xyz.driver.core.rest.{ContextHeaders, Swagger} +import xyz.driver.core.stats.SystemStats +import xyz.driver.core.time.Time +import xyz.driver.core.time.provider.{SystemTimeProvider, TimeProvider} +import xyz.driver.tracing.TracingDirectives.trace +import xyz.driver.tracing.{NoTracer, Tracer} + +import scala.compat.Platform.ConcurrentModificationException +import scala.concurrent.duration._ +import scala.concurrent.{Await, ExecutionContext, Future} +import scala.reflect.runtime.universe._ +import scala.util.Try +import scala.util.control.NonFatal +import scalaz.Scalaz.stringInstance +import scalaz.syntax.equal._ + +class DriverApp(appName: String, + version: String, + gitHash: String, + modules: Seq[Module], + time: TimeProvider = new SystemTimeProvider(), + log: Logger = Logger(LoggerFactory.getLogger(classOf[DriverApp])), + config: Config = core.config.loadDefaultConfig, + interface: String = "::0", + baseUrl: String = "localhost:8080", + scheme: String = "http", + port: Int = 8080, + tracer: Tracer = NoTracer)(implicit actorSystem: ActorSystem, executionContext: ExecutionContext) { + + implicit private lazy val materializer: ActorMaterializer = ActorMaterializer()(actorSystem) + private lazy val http: HttpExt = Http()(actorSystem) + val appEnvironment: String = config.getString("application.environment") + + def run(): Unit = { + activateServices(modules) + scheduleServicesDeactivation(modules) + bindHttp(modules) + Console.print(s"${this.getClass.getName} App is started\n") + } + + def stop(): Unit = { + http.shutdownAllConnectionPools().onComplete { _ => + Await.result(tracer.close(), 15.seconds) // flush out any remaining traces from the buffer + val _ = actorSystem.terminate() + val terminated = Await.result(actorSystem.whenTerminated, 30.seconds) + val addressTerminated = if (terminated.addressTerminated) "is" else "is not" + Console.print(s"${this.getClass.getName} App $addressTerminated stopped ") + } + } + + private def extractHeader(request: HttpRequest)(headerName: String): Option[String] = + request.headers.find(_.name().toLowerCase === headerName).map(_.value()) + + private def allowOrigin(originHeader: Option[Origin]) = + `Access-Control-Allow-Origin`( + originHeader.fold[HttpOriginRange](HttpOriginRange.*)(h => HttpOriginRange(h.origins: _*))) + + protected implicit def rejectionHandler: RejectionHandler = + RejectionHandler + .newBuilder() + .handleAll[MethodRejection] { rejections => + val methods = rejections map (_.supported) + lazy val names = methods map (_.name) mkString ", " + + options { ctx => + optionalHeaderValueByType[Origin](()) { originHeader => + respondWithHeaders(List[HttpHeader]( + Allow(methods), + `Access-Control-Allow-Methods`(methods), + allowOrigin(originHeader), + `Access-Control-Allow-Headers`(rest.AllowedHeaders: _*), + `Access-Control-Expose-Headers`(rest.AllowedHeaders: _*) + )) { + complete(s"Supported methods: $names.") + } + }(ctx) + } ~ + complete(MethodNotAllowed -> s"HTTP method not allowed, supported methods: $names!") + } + .result() + + protected def bindHttp(modules: Seq[Module]): Unit = { + val serviceTypes = modules.flatMap(_.routeTypes) + val swaggerService = swaggerOverride(serviceTypes) + val swaggerRoutes = swaggerService.routes ~ swaggerService.swaggerUI + val versionRt = versionRoute(version, gitHash, time.currentTime()) + + val _ = Future { + http.bindAndHandle( + route2HandlerFlow(extractHost { origin => + trace(tracer) { + extractClientIP { ip => + optionalHeaderValueByType[Origin](()) { + originHeader => + { + ctx => + val trackingId = rest.extractTrackingId(ctx.request) + MDC.put("trackingId", trackingId) + + val updatedStacktrace = + (rest.extractStacktrace(ctx.request) ++ Array(appName)).mkString("->") + MDC.put("stack", updatedStacktrace) + + storeRequestContextToMdc(ctx.request, origin, ip) + + def requestLogging: Future[Unit] = Future { + log.info( + s"""Received request {"method":"${ctx.request.method.value}","url": "${ctx.request.uri}"}""") + } + + val contextWithTrackingId = + ctx.withRequest( + ctx.request + .addHeader(RawHeader(ContextHeaders.TrackingIdHeader, trackingId)) + .addHeader(RawHeader(ContextHeaders.StacktraceHeader, updatedStacktrace))) + + handleExceptions(ExceptionHandler(exceptionHandler))({ + c => + requestLogging.flatMap { _ => + val trackingHeader = RawHeader(ContextHeaders.TrackingIdHeader, trackingId) + + val responseHeaders = List[HttpHeader]( + trackingHeader, + allowOrigin(originHeader), + `Access-Control-Allow-Headers`(rest.AllowedHeaders: _*), + `Access-Control-Expose-Headers`(rest.AllowedHeaders: _*) + ) + + respondWithHeaders(responseHeaders) { + modules.map(_.route).foldLeft(versionRt ~ healthRoute ~ swaggerRoutes)(_ ~ _) + }(c) + } + })(contextWithTrackingId) + } + } + } + } + }), + interface, + port + )(materializer) + } + } + + private def storeRequestContextToMdc(request: HttpRequest, origin: String, ip: RemoteAddress): Unit = { + + MDC.put("origin", origin) + MDC.put("ip", ip.toOption.map(_.getHostAddress).getOrElse("unknown")) + MDC.put("remoteHost", ip.toOption.map(_.getHostName).getOrElse("unknown")) + + MDC.put("xForwardedFor", + extractHeader(request)("x-forwarded-for") + .orElse(extractHeader(request)("x_forwarded_for")) + .getOrElse("unknown")) + MDC.put("remoteAddress", extractHeader(request)("remote-address").getOrElse("unknown")) + MDC.put("userAgent", extractHeader(request)("user-agent").getOrElse("unknown")) + } + + protected def swaggerOverride(apiTypes: Seq[Type]): Swagger = { + new Swagger(baseUrl, Scheme.forValue(scheme), version, actorSystem, apiTypes, config) { + override def generateSwaggerJson: String = { + import io.swagger.models.Swagger + + import scala.collection.JavaConverters._ + + try { + val swagger: Swagger = reader.read(toJavaTypeSet(apiTypes).asJava) + + // Removing trailing spaces + swagger.setPaths( + swagger.getPaths.asScala + .map { + case (key, path) => + key.trim -> path + } + .toMap + .asJava) + + Json.pretty().writeValueAsString(swagger) + } catch { + case NonFatal(t) => + logger.error("Issue with creating swagger.json", t) + throw t + } + } + } + } + + /** + * Override me for custom exception handling + * + * @return Exception handling route for exception type + */ + protected def exceptionHandler: PartialFunction[Throwable, Route] = { + + case is: IllegalStateException => + ctx => + log.warn(s"Request is not allowed to ${ctx.request.method} ${ctx.request.uri}", is) + errorResponse(ctx, BadRequest, message = is.getMessage, is)(ctx) + + case cm: ConcurrentModificationException => + ctx => + log.warn(s"Concurrent modification of the resource ${ctx.request.method} ${ctx.request.uri}", cm) + errorResponse(ctx, Conflict, "Resource was changed concurrently, try requesting a newer version", cm)(ctx) + + case se: SQLException => + ctx => + log.warn(s"Database exception for the resource ${ctx.request.method} ${ctx.request.uri}", se) + errorResponse(ctx, InternalServerError, "Data access error", se)(ctx) + + case t: Throwable => + ctx => + log.warn(s"Request to ${ctx.request.method} ${ctx.request.uri} could not be handled normally", t) + errorResponse(ctx, InternalServerError, t.getMessage, t)(ctx) + } + + protected def errorResponse[T <: Throwable](ctx: RequestContext, + statusCode: StatusCode, + message: String, + exception: T): Route = { + + val trackingId = rest.extractTrackingId(ctx.request) + val tracingHeader = RawHeader(ContextHeaders.TrackingIdHeader, rest.extractTrackingId(ctx.request)) + + MDC.put("trackingId", trackingId) + + optionalHeaderValueByType[Origin](()) { originHeader => + val responseHeaders = List[HttpHeader]( + tracingHeader, + allowOrigin(originHeader), + `Access-Control-Allow-Headers`(rest.AllowedHeaders: _*), + `Access-Control-Expose-Headers`(rest.AllowedHeaders: _*) + ) + + respondWithHeaders(responseHeaders) { + complete(HttpResponse(statusCode, entity = message)) + } + } + } + + protected def versionRoute(version: String, gitHash: String, startupTime: Time): Route = { + import spray.json._ + import DefaultJsonProtocol._ + import SprayJsonSupport._ + + path("version") { + val currentTime = time.currentTime().millis + complete( + Map( + "version" -> version.toJson, + "gitHash" -> gitHash.toJson, + "modules" -> modules.map(_.name).toJson, + "dependencies" -> collectAppDependencies().toJson, + "startupTime" -> startupTime.millis.toString.toJson, + "serverTime" -> currentTime.toString.toJson, + "uptime" -> (currentTime - startupTime.millis).toString.toJson + ).toJson) + } + } + + protected def collectAppDependencies(): Map[String, String] = { + + def serviceWithLocation(serviceName: String): (String, String) = + serviceName -> Try(config.getString(s"services.$serviceName.baseUrl")).getOrElse("not-detected") + + modules.flatMap(module => module.serviceDiscovery.getUsedServices.map(serviceWithLocation).toSeq).toMap + } + + protected def healthRoute: Route = { + import spray.json._ + import DefaultJsonProtocol._ + import SprayJsonSupport._ + import spray.json._ + + val memoryUsage = SystemStats.memoryUsage + val gcStats = SystemStats.garbageCollectorStats + + path("health") { + complete( + Map( + "availableProcessors" -> SystemStats.availableProcessors.toJson, + "memoryUsage" -> Map( + "free" -> memoryUsage.free.toJson, + "total" -> memoryUsage.total.toJson, + "max" -> memoryUsage.max.toJson + ).toJson, + "gcStats" -> Map( + "garbageCollectionTime" -> gcStats.garbageCollectionTime.toJson, + "totalGarbageCollections" -> gcStats.totalGarbageCollections.toJson + ).toJson, + "fileSystemSpace" -> SystemStats.fileSystemSpace.map { f => + Map("path" -> f.path.toJson, + "freeSpace" -> f.freeSpace.toJson, + "totalSpace" -> f.totalSpace.toJson, + "usableSpace" -> f.usableSpace.toJson) + }.toJson, + "operatingSystem" -> SystemStats.operatingSystemStats.toJson + )) + } + } + + /** + * Initializes services + */ + protected def activateServices(services: Seq[Module]): Unit = { + services.foreach { service => + Console.print(s"Service ${service.name} starts ...") + try { + service.activate() + } catch { + case t: Throwable => + log.error(s"Service ${service.name} failed to activate", t) + Console.print(" Failed! (check log)") + } + Console.print(" Done\n") + } + } + + /** + * Schedules services to be deactivated on the app shutdown + */ + protected def scheduleServicesDeactivation(services: Seq[Module]): Unit = { + Runtime.getRuntime.addShutdownHook(new Thread() { + override def run(): Unit = { + services.foreach { service => + Console.print(s"Service ${service.name} shutting down ...\n") + try { + service.deactivate() + } catch { + case t: Throwable => + log.error(s"Service ${service.name} failed to deactivate", t) + Console.print(" Failed! (check log)") + } + Console.print(s"Service ${service.name} is shut down\n") + } + } + }) + } +} diff --git a/src/main/scala/xyz/driver/core/app/Module.scala b/src/main/scala/xyz/driver/core/app/Module.scala new file mode 100644 index 0000000..933b408 --- /dev/null +++ b/src/main/scala/xyz/driver/core/app/Module.scala @@ -0,0 +1,53 @@ +package xyz.driver.core.app + +import akka.http.scaladsl.model.StatusCodes +import akka.http.scaladsl.server.Directives.complete +import akka.http.scaladsl.server.{Route, RouteConcatenation} +import xyz.driver.core.rest.{NoServiceDiscovery, SavingUsedServiceDiscovery, ServiceDiscovery} + +import scala.reflect.runtime.universe._ + +trait Module { + val name: String + def route: Route + def routeTypes: Seq[Type] + + val serviceDiscovery: ServiceDiscovery with SavingUsedServiceDiscovery = new NoServiceDiscovery() + + def activate(): Unit = {} + def deactivate(): Unit = {} +} + +object Module { + + class EmptyModule extends Module { + override val name: String = "Nothing" + + override def route: Route = complete(StatusCodes.OK) + + override def routeTypes: Seq[Type] = Seq.empty[Type] + } + + class SimpleModule(override val name: String, override val route: Route, routeType: Type) extends Module { + def routeTypes: Seq[Type] = Seq(routeType) + } + + /** + * Module implementation which may be used to composed a few + * + * @param name more general name of the composite module, + * must be provided as there is no good way to automatically + * generalize the name from the composed modules' names + * @param modules modules to compose into a single one + */ + class CompositeModule(override val name: String, modules: Seq[Module]) extends Module with RouteConcatenation { + + override def route: Route = RouteConcatenation.concat(modules.map(_.route): _*) + + override def routeTypes: Seq[Type] = modules.flatMap(_.routeTypes) + + override def activate(): Unit = modules.foreach(_.activate()) + + override def deactivate(): Unit = modules.reverse.foreach(_.deactivate()) + } +} diff --git a/src/main/scala/xyz/driver/core/rest/package.scala b/src/main/scala/xyz/driver/core/rest/package.scala index 4c8e13c..e6eb8d6 100644 --- a/src/main/scala/xyz/driver/core/rest/package.scala +++ b/src/main/scala/xyz/driver/core/rest/package.scala @@ -5,27 +5,54 @@ import akka.http.scaladsl.server.Directives._ import akka.http.scaladsl.server._ import akka.stream.scaladsl.Flow import akka.util.ByteString +import xyz.driver.tracing.TracingDirectives import scalaz.Scalaz.{intInstance, stringInstance} import scalaz.syntax.equal._ package object rest { object ContextHeaders { - val AuthenticationTokenHeader = "Authorization" - val PermissionsTokenHeader = "Permissions" - val AuthenticationHeaderPrefix = "Bearer" - val TrackingIdHeader = "X-Trace" - val StacktraceHeader = "X-Stacktrace" - val TracingHeader = trace.TracingHeaderKey + val AuthenticationTokenHeader: String = "Authorization" + val PermissionsTokenHeader: String = "Permissions" + val AuthenticationHeaderPrefix: String = "Bearer" + val TrackingIdHeader: String = "X-Trace" + val StacktraceHeader: String = "X-Stacktrace" + val TraceHeaderName: String = TracingDirectives.TraceHeaderName + val SpanHeaderName: String = TracingDirectives.SpanHeaderName } object AuthProvider { - val AuthenticationTokenHeader = ContextHeaders.AuthenticationTokenHeader - val PermissionsTokenHeader = ContextHeaders.PermissionsTokenHeader - val SetAuthenticationTokenHeader = "set-authorization" - val SetPermissionsTokenHeader = "set-permissions" + val AuthenticationTokenHeader: String = ContextHeaders.AuthenticationTokenHeader + val PermissionsTokenHeader: String = ContextHeaders.PermissionsTokenHeader + val SetAuthenticationTokenHeader: String = "set-authorization" + val SetPermissionsTokenHeader: String = "set-permissions" } + val AllowedHeaders: Seq[String] = + Seq( + "Origin", + "X-Requested-With", + "Content-Type", + "Content-Length", + "Accept", + "X-Trace", + "Access-Control-Allow-Methods", + "Access-Control-Allow-Origin", + "Access-Control-Allow-Headers", + "Server", + "Date", + ContextHeaders.TrackingIdHeader, + ContextHeaders.TraceHeaderName, + ContextHeaders.SpanHeaderName, + ContextHeaders.StacktraceHeader, + ContextHeaders.AuthenticationTokenHeader, + "X-Frame-Options", + "X-Content-Type-Options", + "Strict-Transport-Security", + AuthProvider.SetAuthenticationTokenHeader, + AuthProvider.SetPermissionsTokenHeader + ) + def serviceContext: Directive1[ServiceRequestContext] = extract(ctx => extractServiceContext(ctx.request)) def extractServiceContext(request: HttpRequest): ServiceRequestContext = @@ -44,7 +71,7 @@ package object rest { request.headers.filter { h => h.name === ContextHeaders.AuthenticationTokenHeader || h.name === ContextHeaders.TrackingIdHeader || h.name === ContextHeaders.PermissionsTokenHeader || h.name === ContextHeaders.StacktraceHeader || - h.name === ContextHeaders.TracingHeader + h.name === ContextHeaders.TraceHeaderName || h.name === ContextHeaders.SpanHeaderName } map { header => if (header.name === ContextHeaders.AuthenticationTokenHeader) { header.name -> header.value.stripPrefix(ContextHeaders.AuthenticationHeaderPrefix).trim -- cgit v1.2.3 From 8a34b953fa480bfea7e80a46eb4de6b20b4bca68 Mon Sep 17 00:00:00 2001 From: Zach Smith Date: Mon, 16 Oct 2017 09:47:54 -0700 Subject: Move AuthorizedServiceRequestContext into ServiceRequestContext.scala --- .../rest/AuthorizedServiceRequestContext.scala | 27 ---------------------- .../driver/core/rest/ServiceRequestContext.scala | 23 ++++++++++++++++++ 2 files changed, 23 insertions(+), 27 deletions(-) delete mode 100644 src/main/scala/xyz/driver/core/rest/AuthorizedServiceRequestContext.scala (limited to 'src/main/scala/xyz/driver/core') diff --git a/src/main/scala/xyz/driver/core/rest/AuthorizedServiceRequestContext.scala b/src/main/scala/xyz/driver/core/rest/AuthorizedServiceRequestContext.scala deleted file mode 100644 index 1cf62c9..0000000 --- a/src/main/scala/xyz/driver/core/rest/AuthorizedServiceRequestContext.scala +++ /dev/null @@ -1,27 +0,0 @@ -package xyz.driver.core.rest - -import xyz.driver.core.auth.{PermissionsToken, User} -import xyz.driver.core.generators - -class AuthorizedServiceRequestContext[U <: User](override val trackingId: String = generators.nextUuid().toString, - override val contextHeaders: Map[String, String] = - Map.empty[String, String], - val authenticatedUser: U) - extends ServiceRequestContext { - - def withPermissionsToken(permissionsToken: PermissionsToken): AuthorizedServiceRequestContext[U] = - new AuthorizedServiceRequestContext[U]( - trackingId, - contextHeaders.updated(AuthProvider.PermissionsTokenHeader, permissionsToken.value), - authenticatedUser) - - override def hashCode(): Int = 31 * super.hashCode() + authenticatedUser.hashCode() - - override def equals(obj: Any): Boolean = obj match { - case ctx: AuthorizedServiceRequestContext[U] => super.equals(ctx) && ctx.authenticatedUser == authenticatedUser - case _ => false - } - - override def toString: String = - s"AuthorizedServiceRequestContext($trackingId, $contextHeaders, $authenticatedUser)" -} diff --git a/src/main/scala/xyz/driver/core/rest/ServiceRequestContext.scala b/src/main/scala/xyz/driver/core/rest/ServiceRequestContext.scala index 5235da6..4020d57 100644 --- a/src/main/scala/xyz/driver/core/rest/ServiceRequestContext.scala +++ b/src/main/scala/xyz/driver/core/rest/ServiceRequestContext.scala @@ -37,3 +37,26 @@ class ServiceRequestContext(val trackingId: String = generators.nextUuid().toStr override def toString: String = s"ServiceRequestContext($trackingId, $contextHeaders)" } + +class AuthorizedServiceRequestContext[U <: User](override val trackingId: String = generators.nextUuid().toString, + override val contextHeaders: Map[String, String] = + Map.empty[String, String], + val authenticatedUser: U) + extends ServiceRequestContext { + + def withPermissionsToken(permissionsToken: PermissionsToken): AuthorizedServiceRequestContext[U] = + new AuthorizedServiceRequestContext[U]( + trackingId, + contextHeaders.updated(AuthProvider.PermissionsTokenHeader, permissionsToken.value), + authenticatedUser) + + override def hashCode(): Int = 31 * super.hashCode() + authenticatedUser.hashCode() + + override def equals(obj: Any): Boolean = obj match { + case ctx: AuthorizedServiceRequestContext[U] => super.equals(ctx) && ctx.authenticatedUser == authenticatedUser + case _ => false + } + + override def toString: String = + s"AuthorizedServiceRequestContext($trackingId, $contextHeaders, $authenticatedUser)" +} -- cgit v1.2.3 From 354a21d1a72867b352edbd0aa25b4980938d2749 Mon Sep 17 00:00:00 2001 From: Zach Smith Date: Mon, 16 Oct 2017 11:54:59 -0700 Subject: Consolidate files --- src/main/scala/xyz/driver/core/app/Module.scala | 3 ++- .../scala/xyz/driver/core/rest/HttpClient.scala | 9 --------- .../xyz/driver/core/rest/NoServiceDiscovery.scala | 9 --------- .../scala/xyz/driver/core/rest/Pagination.scala | 3 --- .../core/rest/SavingUsedServiceDiscovery.scala | 14 ------------- src/main/scala/xyz/driver/core/rest/Service.scala | 3 --- .../xyz/driver/core/rest/ServiceDiscovery.scala | 18 +++++++++++++++++ .../xyz/driver/core/rest/ServiceTransport.scala | 13 ------------ src/main/scala/xyz/driver/core/rest/package.scala | 23 +++++++++++++++++++--- 9 files changed, 40 insertions(+), 55 deletions(-) delete mode 100644 src/main/scala/xyz/driver/core/rest/HttpClient.scala delete mode 100644 src/main/scala/xyz/driver/core/rest/NoServiceDiscovery.scala delete mode 100644 src/main/scala/xyz/driver/core/rest/Pagination.scala delete mode 100644 src/main/scala/xyz/driver/core/rest/SavingUsedServiceDiscovery.scala delete mode 100644 src/main/scala/xyz/driver/core/rest/Service.scala delete mode 100644 src/main/scala/xyz/driver/core/rest/ServiceTransport.scala (limited to 'src/main/scala/xyz/driver/core') diff --git a/src/main/scala/xyz/driver/core/app/Module.scala b/src/main/scala/xyz/driver/core/app/Module.scala index 933b408..3aea876 100644 --- a/src/main/scala/xyz/driver/core/app/Module.scala +++ b/src/main/scala/xyz/driver/core/app/Module.scala @@ -3,7 +3,8 @@ package xyz.driver.core.app import akka.http.scaladsl.model.StatusCodes import akka.http.scaladsl.server.Directives.complete import akka.http.scaladsl.server.{Route, RouteConcatenation} -import xyz.driver.core.rest.{NoServiceDiscovery, SavingUsedServiceDiscovery, ServiceDiscovery} +import xyz.driver.core.rest.ServiceDiscovery +import xyz.driver.core.rest.ServiceDiscovery.{NoServiceDiscovery, SavingUsedServiceDiscovery} import scala.reflect.runtime.universe._ diff --git a/src/main/scala/xyz/driver/core/rest/HttpClient.scala b/src/main/scala/xyz/driver/core/rest/HttpClient.scala deleted file mode 100644 index 6f6fda0..0000000 --- a/src/main/scala/xyz/driver/core/rest/HttpClient.scala +++ /dev/null @@ -1,9 +0,0 @@ -package xyz.driver.core.rest - -import akka.http.scaladsl.model.{HttpRequest, HttpResponse} - -import scala.concurrent.Future - -trait HttpClient { - def makeRequest(request: HttpRequest): Future[HttpResponse] -} diff --git a/src/main/scala/xyz/driver/core/rest/NoServiceDiscovery.scala b/src/main/scala/xyz/driver/core/rest/NoServiceDiscovery.scala deleted file mode 100644 index 9d2febd..0000000 --- a/src/main/scala/xyz/driver/core/rest/NoServiceDiscovery.scala +++ /dev/null @@ -1,9 +0,0 @@ -package xyz.driver.core.rest - -import xyz.driver.core.Name - -class NoServiceDiscovery extends ServiceDiscovery with SavingUsedServiceDiscovery { - - def discover[T <: Service](serviceName: Name[Service]): T = - throw new IllegalArgumentException(s"Service with name $serviceName is unknown") -} diff --git a/src/main/scala/xyz/driver/core/rest/Pagination.scala b/src/main/scala/xyz/driver/core/rest/Pagination.scala deleted file mode 100644 index f97660f..0000000 --- a/src/main/scala/xyz/driver/core/rest/Pagination.scala +++ /dev/null @@ -1,3 +0,0 @@ -package xyz.driver.core.rest - -final case class Pagination(pageSize: Int, pageNumber: Int) diff --git a/src/main/scala/xyz/driver/core/rest/SavingUsedServiceDiscovery.scala b/src/main/scala/xyz/driver/core/rest/SavingUsedServiceDiscovery.scala deleted file mode 100644 index 8018bdf..0000000 --- a/src/main/scala/xyz/driver/core/rest/SavingUsedServiceDiscovery.scala +++ /dev/null @@ -1,14 +0,0 @@ -package xyz.driver.core.rest - -import xyz.driver.core.Name - -trait SavingUsedServiceDiscovery { - - private val usedServices = new scala.collection.mutable.HashSet[String]() - - def saveServiceUsage(serviceName: Name[Service]): Unit = usedServices.synchronized { - usedServices += serviceName.value - } - - def getUsedServices: Set[String] = usedServices.synchronized { usedServices.toSet } -} diff --git a/src/main/scala/xyz/driver/core/rest/Service.scala b/src/main/scala/xyz/driver/core/rest/Service.scala deleted file mode 100644 index 8216ab7..0000000 --- a/src/main/scala/xyz/driver/core/rest/Service.scala +++ /dev/null @@ -1,3 +0,0 @@ -package xyz.driver.core.rest - -trait Service diff --git a/src/main/scala/xyz/driver/core/rest/ServiceDiscovery.scala b/src/main/scala/xyz/driver/core/rest/ServiceDiscovery.scala index f0f3b5b..5f589a9 100644 --- a/src/main/scala/xyz/driver/core/rest/ServiceDiscovery.scala +++ b/src/main/scala/xyz/driver/core/rest/ServiceDiscovery.scala @@ -6,3 +6,21 @@ trait ServiceDiscovery { def discover[T <: Service](serviceName: Name[Service]): T } + +object ServiceDiscovery { + trait SavingUsedServiceDiscovery { + private val usedServices = new scala.collection.mutable.HashSet[String]() + + def saveServiceUsage(serviceName: Name[Service]): Unit = usedServices.synchronized { + usedServices += serviceName.value + } + + def getUsedServices: Set[String] = usedServices.synchronized { usedServices.toSet } + } + + class NoServiceDiscovery extends ServiceDiscovery with SavingUsedServiceDiscovery { + + def discover[T <: Service](serviceName: Name[Service]): T = + throw new IllegalArgumentException(s"Service with name $serviceName is unknown") + } +} diff --git a/src/main/scala/xyz/driver/core/rest/ServiceTransport.scala b/src/main/scala/xyz/driver/core/rest/ServiceTransport.scala deleted file mode 100644 index 9c0c429..0000000 --- a/src/main/scala/xyz/driver/core/rest/ServiceTransport.scala +++ /dev/null @@ -1,13 +0,0 @@ -package xyz.driver.core.rest - -import akka.http.scaladsl.model.{HttpRequest, HttpResponse, ResponseEntity} -import akka.http.scaladsl.unmarshalling.Unmarshal - -import scala.concurrent.Future - -trait ServiceTransport { - - def sendRequestGetResponse(context: ServiceRequestContext)(requestStub: HttpRequest): Future[HttpResponse] - - def sendRequest(context: ServiceRequestContext)(requestStub: HttpRequest): Future[Unmarshal[ResponseEntity]] -} diff --git a/src/main/scala/xyz/driver/core/rest/package.scala b/src/main/scala/xyz/driver/core/rest/package.scala index e6eb8d6..6019c33 100644 --- a/src/main/scala/xyz/driver/core/rest/package.scala +++ b/src/main/scala/xyz/driver/core/rest/package.scala @@ -1,16 +1,33 @@ -package xyz.driver.core +package xyz.driver.core.rest -import akka.http.scaladsl.model.HttpRequest +import akka.http.scaladsl.model.{HttpRequest, HttpResponse, ResponseEntity} import akka.http.scaladsl.server.Directives._ import akka.http.scaladsl.server._ +import akka.http.scaladsl.unmarshalling.Unmarshal import akka.stream.scaladsl.Flow import akka.util.ByteString import xyz.driver.tracing.TracingDirectives +import scala.concurrent.Future import scalaz.Scalaz.{intInstance, stringInstance} import scalaz.syntax.equal._ -package object rest { +trait Service + +trait HttpClient { + def makeRequest(request: HttpRequest): Future[HttpResponse] +} + +trait ServiceTransport { + + def sendRequestGetResponse(context: ServiceRequestContext)(requestStub: HttpRequest): Future[HttpResponse] + + def sendRequest(context: ServiceRequestContext)(requestStub: HttpRequest): Future[Unmarshal[ResponseEntity]] +} + +final case class Pagination(pageSize: Int, pageNumber: Int) + +object `package` { object ContextHeaders { val AuthenticationTokenHeader: String = "Authorization" val PermissionsTokenHeader: String = "Permissions" -- cgit v1.2.3 From 3330c62e4ce6e775313a3bd7ce74aed871cd06d0 Mon Sep 17 00:00:00 2001 From: Zach Smith Date: Mon, 16 Oct 2017 12:58:14 -0700 Subject: Move Implicits to package.scala --- src/main/scala/xyz/driver/core/rest/Implicits.scala | 17 ----------------- src/main/scala/xyz/driver/core/rest/package.scala | 19 ++++++++++++++++++- 2 files changed, 18 insertions(+), 18 deletions(-) delete mode 100644 src/main/scala/xyz/driver/core/rest/Implicits.scala (limited to 'src/main/scala/xyz/driver/core') diff --git a/src/main/scala/xyz/driver/core/rest/Implicits.scala b/src/main/scala/xyz/driver/core/rest/Implicits.scala deleted file mode 100644 index 8b499dd..0000000 --- a/src/main/scala/xyz/driver/core/rest/Implicits.scala +++ /dev/null @@ -1,17 +0,0 @@ -package xyz.driver.core.rest - -import akka.http.scaladsl.marshalling.{ToEntityMarshaller, ToResponseMarshallable} -import akka.http.scaladsl.model.StatusCodes - -import scala.concurrent.Future -import scalaz.{Functor, OptionT} - -object Implicits { - implicit class OptionTRestAdditions[T](optionT: OptionT[Future, T]) { - def responseOrNotFound(successCode: StatusCodes.Success = StatusCodes.OK)( - implicit F: Functor[Future], - em: ToEntityMarshaller[T]): Future[ToResponseMarshallable] = { - optionT.fold[ToResponseMarshallable](successCode -> _, StatusCodes.NotFound -> None) - } - } -} diff --git a/src/main/scala/xyz/driver/core/rest/package.scala b/src/main/scala/xyz/driver/core/rest/package.scala index 6019c33..17cc3bb 100644 --- a/src/main/scala/xyz/driver/core/rest/package.scala +++ b/src/main/scala/xyz/driver/core/rest/package.scala @@ -1,6 +1,8 @@ package xyz.driver.core.rest -import akka.http.scaladsl.model.{HttpRequest, HttpResponse, ResponseEntity} +import akka.http.scaladsl.marshalling.{ToEntityMarshaller, ToResponseMarshallable} +import akka.http.scaladsl.model.headers.{HttpOriginRange, Origin, `Access-Control-Allow-Origin`} +import akka.http.scaladsl.model.{HttpRequest, HttpResponse, ResponseEntity, StatusCodes} import akka.http.scaladsl.server.Directives._ import akka.http.scaladsl.server._ import akka.http.scaladsl.unmarshalling.Unmarshal @@ -9,6 +11,7 @@ import akka.util.ByteString import xyz.driver.tracing.TracingDirectives import scala.concurrent.Future +import scalaz.{Functor, OptionT} import scalaz.Scalaz.{intInstance, stringInstance} import scalaz.syntax.equal._ @@ -27,6 +30,16 @@ trait ServiceTransport { final case class Pagination(pageSize: Int, pageNumber: Int) +object Implicits { + implicit class OptionTRestAdditions[T](optionT: OptionT[Future, T]) { + def responseOrNotFound(successCode: StatusCodes.Success = StatusCodes.OK)( + implicit F: Functor[Future], + em: ToEntityMarshaller[T]): Future[ToResponseMarshallable] = { + optionT.fold[ToResponseMarshallable](successCode -> _, StatusCodes.NotFound -> None) + } + } +} + object `package` { object ContextHeaders { val AuthenticationTokenHeader: String = "Authorization" @@ -70,6 +83,10 @@ object `package` { AuthProvider.SetPermissionsTokenHeader ) + def allowOrigin(originHeader: Option[Origin]): `Access-Control-Allow-Origin` = + `Access-Control-Allow-Origin`( + originHeader.fold[HttpOriginRange](HttpOriginRange.*)(h => HttpOriginRange(h.origins: _*))) + def serviceContext: Directive1[ServiceRequestContext] = extract(ctx => extractServiceContext(ctx.request)) def extractServiceContext(request: HttpRequest): ServiceRequestContext = -- cgit v1.2.3 From f1f4183bb7e40347e15347de87e06c1e1d854827 Mon Sep 17 00:00:00 2001 From: Zach Smith Date: Tue, 17 Oct 2017 13:56:24 -0700 Subject: Move implicit class extension to package object --- src/main/scala/xyz/driver/core/rest/package.scala | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) (limited to 'src/main/scala/xyz/driver/core') diff --git a/src/main/scala/xyz/driver/core/rest/package.scala b/src/main/scala/xyz/driver/core/rest/package.scala index 17cc3bb..942ca3a 100644 --- a/src/main/scala/xyz/driver/core/rest/package.scala +++ b/src/main/scala/xyz/driver/core/rest/package.scala @@ -30,7 +30,7 @@ trait ServiceTransport { final case class Pagination(pageSize: Int, pageNumber: Int) -object Implicits { +object `package` { implicit class OptionTRestAdditions[T](optionT: OptionT[Future, T]) { def responseOrNotFound(successCode: StatusCodes.Success = StatusCodes.OK)( implicit F: Functor[Future], @@ -38,9 +38,7 @@ object Implicits { optionT.fold[ToResponseMarshallable](successCode -> _, StatusCodes.NotFound -> None) } } -} -object `package` { object ContextHeaders { val AuthenticationTokenHeader: String = "Authorization" val PermissionsTokenHeader: String = "Permissions" -- cgit v1.2.3 From 1aaaf7a5ecf2cd28350fff872e334f9f6186966a Mon Sep 17 00:00:00 2001 From: Zach Smith Date: Tue, 17 Oct 2017 14:42:39 -0700 Subject: Rename files containing multiple classes to lowercase --- src/main/scala/xyz/driver/core/app/Module.scala | 54 ------------------- src/main/scala/xyz/driver/core/app/module.scala | 50 +++++++++++++++++ .../xyz/driver/core/rest/ServiceDiscovery.scala | 26 --------- .../driver/core/rest/ServiceRequestContext.scala | 62 ---------------------- .../xyz/driver/core/rest/serviceDiscovery.scala | 24 +++++++++ .../driver/core/rest/serviceRequestContext.scala | 62 ++++++++++++++++++++++ 6 files changed, 136 insertions(+), 142 deletions(-) delete mode 100644 src/main/scala/xyz/driver/core/app/Module.scala create mode 100644 src/main/scala/xyz/driver/core/app/module.scala delete mode 100644 src/main/scala/xyz/driver/core/rest/ServiceDiscovery.scala delete mode 100644 src/main/scala/xyz/driver/core/rest/ServiceRequestContext.scala create mode 100644 src/main/scala/xyz/driver/core/rest/serviceDiscovery.scala create mode 100644 src/main/scala/xyz/driver/core/rest/serviceRequestContext.scala (limited to 'src/main/scala/xyz/driver/core') diff --git a/src/main/scala/xyz/driver/core/app/Module.scala b/src/main/scala/xyz/driver/core/app/Module.scala deleted file mode 100644 index 3aea876..0000000 --- a/src/main/scala/xyz/driver/core/app/Module.scala +++ /dev/null @@ -1,54 +0,0 @@ -package xyz.driver.core.app - -import akka.http.scaladsl.model.StatusCodes -import akka.http.scaladsl.server.Directives.complete -import akka.http.scaladsl.server.{Route, RouteConcatenation} -import xyz.driver.core.rest.ServiceDiscovery -import xyz.driver.core.rest.ServiceDiscovery.{NoServiceDiscovery, SavingUsedServiceDiscovery} - -import scala.reflect.runtime.universe._ - -trait Module { - val name: String - def route: Route - def routeTypes: Seq[Type] - - val serviceDiscovery: ServiceDiscovery with SavingUsedServiceDiscovery = new NoServiceDiscovery() - - def activate(): Unit = {} - def deactivate(): Unit = {} -} - -object Module { - - class EmptyModule extends Module { - override val name: String = "Nothing" - - override def route: Route = complete(StatusCodes.OK) - - override def routeTypes: Seq[Type] = Seq.empty[Type] - } - - class SimpleModule(override val name: String, override val route: Route, routeType: Type) extends Module { - def routeTypes: Seq[Type] = Seq(routeType) - } - - /** - * Module implementation which may be used to composed a few - * - * @param name more general name of the composite module, - * must be provided as there is no good way to automatically - * generalize the name from the composed modules' names - * @param modules modules to compose into a single one - */ - class CompositeModule(override val name: String, modules: Seq[Module]) extends Module with RouteConcatenation { - - override def route: Route = RouteConcatenation.concat(modules.map(_.route): _*) - - override def routeTypes: Seq[Type] = modules.flatMap(_.routeTypes) - - override def activate(): Unit = modules.foreach(_.activate()) - - override def deactivate(): Unit = modules.reverse.foreach(_.deactivate()) - } -} diff --git a/src/main/scala/xyz/driver/core/app/module.scala b/src/main/scala/xyz/driver/core/app/module.scala new file mode 100644 index 0000000..c6f979f --- /dev/null +++ b/src/main/scala/xyz/driver/core/app/module.scala @@ -0,0 +1,50 @@ +package xyz.driver.core.app + +import akka.http.scaladsl.model.StatusCodes +import akka.http.scaladsl.server.Directives.complete +import akka.http.scaladsl.server.{Route, RouteConcatenation} +import xyz.driver.core.rest.{NoServiceDiscovery, SavingUsedServiceDiscovery, ServiceDiscovery} + +import scala.reflect.runtime.universe._ + +trait Module { + val name: String + def route: Route + def routeTypes: Seq[Type] + + val serviceDiscovery: ServiceDiscovery with SavingUsedServiceDiscovery = new NoServiceDiscovery() + + def activate(): Unit = {} + def deactivate(): Unit = {} +} + +class EmptyModule extends Module { + override val name: String = "Nothing" + + override def route: Route = complete(StatusCodes.OK) + + override def routeTypes: Seq[Type] = Seq.empty[Type] +} + +class SimpleModule(override val name: String, override val route: Route, routeType: Type) extends Module { + def routeTypes: Seq[Type] = Seq(routeType) +} + +/** + * Module implementation which may be used to composed a few + * + * @param name more general name of the composite module, + * must be provided as there is no good way to automatically + * generalize the name from the composed modules' names + * @param modules modules to compose into a single one + */ +class CompositeModule(override val name: String, modules: Seq[Module]) extends Module with RouteConcatenation { + + override def route: Route = RouteConcatenation.concat(modules.map(_.route): _*) + + override def routeTypes: Seq[Type] = modules.flatMap(_.routeTypes) + + override def activate(): Unit = modules.foreach(_.activate()) + + override def deactivate(): Unit = modules.reverse.foreach(_.deactivate()) +} diff --git a/src/main/scala/xyz/driver/core/rest/ServiceDiscovery.scala b/src/main/scala/xyz/driver/core/rest/ServiceDiscovery.scala deleted file mode 100644 index 5f589a9..0000000 --- a/src/main/scala/xyz/driver/core/rest/ServiceDiscovery.scala +++ /dev/null @@ -1,26 +0,0 @@ -package xyz.driver.core.rest - -import xyz.driver.core.Name - -trait ServiceDiscovery { - - def discover[T <: Service](serviceName: Name[Service]): T -} - -object ServiceDiscovery { - trait SavingUsedServiceDiscovery { - private val usedServices = new scala.collection.mutable.HashSet[String]() - - def saveServiceUsage(serviceName: Name[Service]): Unit = usedServices.synchronized { - usedServices += serviceName.value - } - - def getUsedServices: Set[String] = usedServices.synchronized { usedServices.toSet } - } - - class NoServiceDiscovery extends ServiceDiscovery with SavingUsedServiceDiscovery { - - def discover[T <: Service](serviceName: Name[Service]): T = - throw new IllegalArgumentException(s"Service with name $serviceName is unknown") - } -} diff --git a/src/main/scala/xyz/driver/core/rest/ServiceRequestContext.scala b/src/main/scala/xyz/driver/core/rest/ServiceRequestContext.scala deleted file mode 100644 index 4020d57..0000000 --- a/src/main/scala/xyz/driver/core/rest/ServiceRequestContext.scala +++ /dev/null @@ -1,62 +0,0 @@ -package xyz.driver.core.rest - -import xyz.driver.core.auth.{AuthToken, PermissionsToken, User} -import xyz.driver.core.generators - -import scalaz.Scalaz.{mapEqual, stringInstance} -import scalaz.syntax.equal._ - -class ServiceRequestContext(val trackingId: String = generators.nextUuid().toString, - val contextHeaders: Map[String, String] = Map.empty[String, String]) { - def authToken: Option[AuthToken] = - contextHeaders.get(AuthProvider.AuthenticationTokenHeader).map(AuthToken.apply) - - def permissionsToken: Option[PermissionsToken] = - contextHeaders.get(AuthProvider.PermissionsTokenHeader).map(PermissionsToken.apply) - - def withAuthToken(authToken: AuthToken): ServiceRequestContext = - new ServiceRequestContext( - trackingId, - contextHeaders.updated(AuthProvider.AuthenticationTokenHeader, authToken.value) - ) - - def withAuthenticatedUser[U <: User](authToken: AuthToken, user: U): AuthorizedServiceRequestContext[U] = - new AuthorizedServiceRequestContext( - trackingId, - contextHeaders.updated(AuthProvider.AuthenticationTokenHeader, authToken.value), - user - ) - - override def hashCode(): Int = - Seq[Any](trackingId, contextHeaders).foldLeft(31)((result, obj) => 31 * result + obj.hashCode()) - - override def equals(obj: Any): Boolean = obj match { - case ctx: ServiceRequestContext => trackingId === ctx.trackingId && contextHeaders === ctx.contextHeaders - case _ => false - } - - override def toString: String = s"ServiceRequestContext($trackingId, $contextHeaders)" -} - -class AuthorizedServiceRequestContext[U <: User](override val trackingId: String = generators.nextUuid().toString, - override val contextHeaders: Map[String, String] = - Map.empty[String, String], - val authenticatedUser: U) - extends ServiceRequestContext { - - def withPermissionsToken(permissionsToken: PermissionsToken): AuthorizedServiceRequestContext[U] = - new AuthorizedServiceRequestContext[U]( - trackingId, - contextHeaders.updated(AuthProvider.PermissionsTokenHeader, permissionsToken.value), - authenticatedUser) - - override def hashCode(): Int = 31 * super.hashCode() + authenticatedUser.hashCode() - - override def equals(obj: Any): Boolean = obj match { - case ctx: AuthorizedServiceRequestContext[U] => super.equals(ctx) && ctx.authenticatedUser == authenticatedUser - case _ => false - } - - override def toString: String = - s"AuthorizedServiceRequestContext($trackingId, $contextHeaders, $authenticatedUser)" -} diff --git a/src/main/scala/xyz/driver/core/rest/serviceDiscovery.scala b/src/main/scala/xyz/driver/core/rest/serviceDiscovery.scala new file mode 100644 index 0000000..55f1a2e --- /dev/null +++ b/src/main/scala/xyz/driver/core/rest/serviceDiscovery.scala @@ -0,0 +1,24 @@ +package xyz.driver.core.rest + +import xyz.driver.core.Name + +trait ServiceDiscovery { + + def discover[T <: Service](serviceName: Name[Service]): T +} + +trait SavingUsedServiceDiscovery { + private val usedServices = new scala.collection.mutable.HashSet[String]() + + def saveServiceUsage(serviceName: Name[Service]): Unit = usedServices.synchronized { + usedServices += serviceName.value + } + + def getUsedServices: Set[String] = usedServices.synchronized { usedServices.toSet } +} + +class NoServiceDiscovery extends ServiceDiscovery with SavingUsedServiceDiscovery { + + def discover[T <: Service](serviceName: Name[Service]): T = + throw new IllegalArgumentException(s"Service with name $serviceName is unknown") +} diff --git a/src/main/scala/xyz/driver/core/rest/serviceRequestContext.scala b/src/main/scala/xyz/driver/core/rest/serviceRequestContext.scala new file mode 100644 index 0000000..4020d57 --- /dev/null +++ b/src/main/scala/xyz/driver/core/rest/serviceRequestContext.scala @@ -0,0 +1,62 @@ +package xyz.driver.core.rest + +import xyz.driver.core.auth.{AuthToken, PermissionsToken, User} +import xyz.driver.core.generators + +import scalaz.Scalaz.{mapEqual, stringInstance} +import scalaz.syntax.equal._ + +class ServiceRequestContext(val trackingId: String = generators.nextUuid().toString, + val contextHeaders: Map[String, String] = Map.empty[String, String]) { + def authToken: Option[AuthToken] = + contextHeaders.get(AuthProvider.AuthenticationTokenHeader).map(AuthToken.apply) + + def permissionsToken: Option[PermissionsToken] = + contextHeaders.get(AuthProvider.PermissionsTokenHeader).map(PermissionsToken.apply) + + def withAuthToken(authToken: AuthToken): ServiceRequestContext = + new ServiceRequestContext( + trackingId, + contextHeaders.updated(AuthProvider.AuthenticationTokenHeader, authToken.value) + ) + + def withAuthenticatedUser[U <: User](authToken: AuthToken, user: U): AuthorizedServiceRequestContext[U] = + new AuthorizedServiceRequestContext( + trackingId, + contextHeaders.updated(AuthProvider.AuthenticationTokenHeader, authToken.value), + user + ) + + override def hashCode(): Int = + Seq[Any](trackingId, contextHeaders).foldLeft(31)((result, obj) => 31 * result + obj.hashCode()) + + override def equals(obj: Any): Boolean = obj match { + case ctx: ServiceRequestContext => trackingId === ctx.trackingId && contextHeaders === ctx.contextHeaders + case _ => false + } + + override def toString: String = s"ServiceRequestContext($trackingId, $contextHeaders)" +} + +class AuthorizedServiceRequestContext[U <: User](override val trackingId: String = generators.nextUuid().toString, + override val contextHeaders: Map[String, String] = + Map.empty[String, String], + val authenticatedUser: U) + extends ServiceRequestContext { + + def withPermissionsToken(permissionsToken: PermissionsToken): AuthorizedServiceRequestContext[U] = + new AuthorizedServiceRequestContext[U]( + trackingId, + contextHeaders.updated(AuthProvider.PermissionsTokenHeader, permissionsToken.value), + authenticatedUser) + + override def hashCode(): Int = 31 * super.hashCode() + authenticatedUser.hashCode() + + override def equals(obj: Any): Boolean = obj match { + case ctx: AuthorizedServiceRequestContext[U] => super.equals(ctx) && ctx.authenticatedUser == authenticatedUser + case _ => false + } + + override def toString: String = + s"AuthorizedServiceRequestContext($trackingId, $contextHeaders, $authenticatedUser)" +} -- cgit v1.2.3