From 16bdae27befd9cf3b723ad919ba2140b38d18c48 Mon Sep 17 00:00:00 2001 From: vlad Date: Tue, 1 Nov 2016 15:19:36 -0700 Subject: DIR-135 Consistent request context extraction --- src/main/scala/xyz/driver/core/rest.scala | 55 ++++++++++++++++++++----------- 1 file changed, 36 insertions(+), 19 deletions(-) (limited to 'src/main/scala/xyz/driver/core/rest.scala') diff --git a/src/main/scala/xyz/driver/core/rest.scala b/src/main/scala/xyz/driver/core/rest.scala index eaf97db..c52d9e0 100644 --- a/src/main/scala/xyz/driver/core/rest.scala +++ b/src/main/scala/xyz/driver/core/rest.scala @@ -4,15 +4,13 @@ import akka.actor.ActorSystem import akka.http.scaladsl.Http import akka.http.scaladsl.model._ import akka.http.scaladsl.model.headers.RawHeader +import akka.http.scaladsl.server.RequestContext import akka.http.scaladsl.unmarshalling.Unmarshal import akka.stream.ActorMaterializer -import akka.stream.scaladsl.Flow -import akka.util.ByteString import com.github.swagger.akka.model._ import com.github.swagger.akka.{HasActorSystem, SwaggerHttpService} import com.typesafe.config.Config -import xyz.driver.core.auth.{AuthService, AuthToken} -import xyz.driver.core.crypto.Crypto +import xyz.driver.core.auth.AuthService import xyz.driver.core.logging.Logger import xyz.driver.core.stats.Stats import xyz.driver.core.time.TimeRange @@ -20,15 +18,41 @@ import xyz.driver.core.time.provider.TimeProvider import scala.concurrent.{ExecutionContext, Future} import scala.util.{Failure, Success} -import scalaz.{Failure => _, Success => _} +import scalaz.Scalaz.{Id => _, _} object rest { + object ContextHeaders { + val AuthenticationTokenHeader = AuthService.AuthenticationTokenHeader + val TrackingIdHeader = "l5d-ctx-trace" // https://linkerd.io/doc/0.7.4/linkerd/protocol-http/ + } + + final case class ServiceRequestContext(trackingId: String, contextHeaders: Map[String, String]) + + def serviceContext(ctx: RequestContext): ServiceRequestContext = { + ServiceRequestContext(extractTrackingId(ctx), extractContextHeaders(ctx)) + } + + def extractTrackingId(ctx: RequestContext): String = { + ctx.request.headers + .find(_.name == ContextHeaders.TrackingIdHeader) + .fold(java.util.UUID.randomUUID.toString)(_.value()) + } + + def extractContextHeaders(ctx: RequestContext): Map[String, String] = { + ctx.request.headers.filter { h => + h.lowercaseName.startsWith("l5d-") || h.name === ContextHeaders.AuthenticationTokenHeader + } map { header => + header.name -> header.value + } toMap + } + + trait Service trait ServiceTransport { - def sendRequest(authToken: AuthToken)(requestStub: HttpRequest): Future[Unmarshal[ResponseEntity]] + def sendRequest(context: ServiceRequestContext)(requestStub: HttpRequest): Future[Unmarshal[ResponseEntity]] } trait ServiceDiscovery { @@ -37,25 +61,18 @@ object rest { } class HttpRestServiceTransport(actorSystem: ActorSystem, executionContext: ExecutionContext, - crypto: Crypto, log: Logger, stats: Stats, time: TimeProvider) extends ServiceTransport { + log: Logger, stats: Stats, time: TimeProvider) extends ServiceTransport { protected implicit val materializer = ActorMaterializer()(actorSystem) protected implicit val execution = executionContext - def sendRequest(authToken: AuthToken)(requestStub: HttpRequest): Future[Unmarshal[ResponseEntity]] = { + def sendRequest(context: ServiceRequestContext)(requestStub: HttpRequest): Future[Unmarshal[ResponseEntity]] = { val requestTime = time.currentTime() - val encryptionFlow = Flow[ByteString] map { bytes => - ByteString(crypto.encrypt(crypto.keyForToken(authToken))(bytes.toArray)) - } - val decryptionFlow = Flow[ByteString] map { bytes => - ByteString(crypto.decrypt(crypto.keyForToken(authToken))(bytes.toArray)) - } - val request = (if(requestStub.entity.isKnownEmpty()) requestStub else { - requestStub.withEntity(requestStub.entity.transformDataBytes(encryptionFlow)) - }).withHeaders(RawHeader(AuthService.AuthenticationTokenHeader, authToken.value.value), - RawHeader(AuthService.TrackingIdHeader, authToken.trackingId)) + val request = requestStub + .withHeaders(RawHeader(ContextHeaders.TrackingIdHeader, context.trackingId)) + .withHeaders(context.contextHeaders.toSeq.map { h => RawHeader(h._1, h._2): HttpHeader }: _*) log.audit(s"Sending to ${request.uri} request $request") @@ -65,7 +82,7 @@ object rest { } else if(response.status.isFailure()) { throw new Exception("Http status is failure " + response.status) } else { - Unmarshal(response.entity.transformDataBytes(decryptionFlow)) + Unmarshal(response.entity) } } -- cgit v1.2.3