aboutsummaryrefslogtreecommitdiff
path: root/src/main/scala/xyz/driver/core/rest.scala
diff options
context:
space:
mode:
Diffstat (limited to 'src/main/scala/xyz/driver/core/rest.scala')
-rw-r--r--src/main/scala/xyz/driver/core/rest.scala153
1 files changed, 153 insertions, 0 deletions
diff --git a/src/main/scala/xyz/driver/core/rest.scala b/src/main/scala/xyz/driver/core/rest.scala
new file mode 100644
index 0000000..f05a800
--- /dev/null
+++ b/src/main/scala/xyz/driver/core/rest.scala
@@ -0,0 +1,153 @@
+package xyz.driver.core
+
+import akka.actor.ActorSystem
+import akka.http.scaladsl.Http
+import akka.http.scaladsl.model._
+import akka.http.scaladsl.model.headers.RawHeader
+import akka.http.scaladsl.unmarshalling.Unmarshal
+import akka.stream.ActorMaterializer
+import com.github.swagger.akka.model._
+import com.github.swagger.akka.{HasActorSystem, SwaggerHttpService}
+import com.typesafe.config.Config
+import xyz.driver.core.logging.Logger
+import xyz.driver.core.stats.Stats
+import xyz.driver.core.time.TimeRange
+import xyz.driver.core.time.provider.TimeProvider
+
+import scala.concurrent.{ExecutionContext, Future}
+import scala.util.{Failure, Success}
+import scalaz.Scalaz.{Id => _, _}
+
+object rest {
+
+ object ContextHeaders {
+ val AuthenticationTokenHeader = "WWW-Authenticate"
+ val TrackingIdHeader = "X-Trace"
+
+ object LinkerD {
+ // https://linkerd.io/doc/0.7.4/linkerd/protocol-http/
+ def isLinkerD(headerName: String) = headerName.startsWith("l5d-")
+ }
+ }
+
+ final case class ServiceRequestContext(
+ trackingId: String = generators.nextUuid().toString,
+ contextHeaders: Map[String, String] = Map.empty[String, String])
+
+ import akka.http.scaladsl.server._
+ import Directives._
+
+ def serviceContext: Directive1[ServiceRequestContext] = extract(ctx => extractServiceContext(ctx))
+
+ def extractServiceContext(ctx: RequestContext): ServiceRequestContext =
+ ServiceRequestContext(extractTrackingId(ctx), extractContextHeaders(ctx))
+
+ def extractTrackingId(ctx: RequestContext): String = {
+ ctx.request.headers
+ .find(_.name == ContextHeaders.TrackingIdHeader)
+ .fold(java.util.UUID.randomUUID.toString)(_.value())
+ }
+
+ def extractContextHeaders(ctx: RequestContext): Map[String, String] = {
+ ctx.request.headers.filter { h =>
+ h.name === ContextHeaders.AuthenticationTokenHeader || h.name === ContextHeaders.TrackingIdHeader
+ // || ContextHeaders.LinkerD.isLinkerD(h.lowercaseName)
+ } map { header =>
+ header.name -> header.value
+ } toMap
+ }
+
+
+ trait Service
+
+ trait ServiceTransport {
+
+ def sendRequest(context: ServiceRequestContext)(requestStub: HttpRequest): Future[Unmarshal[ResponseEntity]]
+ }
+
+ trait ServiceDiscovery {
+
+ def discover[T <: Service](serviceName: Name[Service]): T
+ }
+
+ class HttpRestServiceTransport(actorSystem: ActorSystem, executionContext: ExecutionContext,
+ log: Logger, stats: Stats, time: TimeProvider) extends ServiceTransport {
+
+ protected implicit val materializer = ActorMaterializer()(actorSystem)
+ protected implicit val execution = executionContext
+
+ def sendRequest(context: ServiceRequestContext)(requestStub: HttpRequest): Future[Unmarshal[ResponseEntity]] = {
+
+ val requestTime = time.currentTime()
+
+ val request = requestStub
+ .withHeaders(RawHeader(ContextHeaders.TrackingIdHeader, context.trackingId))
+ .withHeaders(context.contextHeaders.toSeq.map { h => RawHeader(h._1, h._2): HttpHeader }: _*)
+
+ log.audit(s"Sending to ${request.uri} request $request with tracking id ${context.trackingId}")
+
+ val responseEntity = Http()(actorSystem).singleRequest(request)(materializer) map { response =>
+ if(response.status == StatusCodes.NotFound) {
+ Unmarshal(HttpEntity.Empty: ResponseEntity)
+ } else if(response.status.isFailure()) {
+ throw new Exception(s"Http status is failure ${response.status}")
+ } else {
+ Unmarshal(response.entity)
+ }
+ }
+
+ responseEntity.onComplete {
+ case Success(r) =>
+ val responseTime = time.currentTime()
+ log.audit(s"Response from ${request.uri} to request $requestStub is successful")
+ stats.recordStats(Seq("request", request.uri.toString, "success"), TimeRange(requestTime, responseTime), 1)
+
+ case Failure(t: Throwable) =>
+ val responseTime = time.currentTime()
+ log.audit(s"Failed to receive response from ${request.uri} to request $requestStub")
+ log.error(s"Failed to receive response from ${request.uri} to request $requestStub", t)
+ stats.recordStats(Seq("request", request.uri.toString, "fail"), TimeRange(requestTime, responseTime), 1)
+ } (executionContext)
+
+ responseEntity
+ }
+ }
+
+ import scala.reflect.runtime.universe._
+
+ class Swagger(override val host: String,
+ version: String,
+ override val actorSystem: ActorSystem,
+ override val apiTypes: Seq[Type],
+ val config: Config) extends SwaggerHttpService with HasActorSystem {
+
+ val materializer = ActorMaterializer()(actorSystem)
+
+ override val basePath = config.getString("swagger.basePath")
+ override val apiDocsPath = config.getString("swagger.docsPath")
+
+ override val info = Info(
+ config.getString("swagger.apiInfo.description"),
+ version,
+ config.getString("swagger.apiInfo.title"),
+ config.getString("swagger.apiInfo.termsOfServiceUrl"),
+ contact = Some(Contact(
+ config.getString("swagger.apiInfo.contact.name"),
+ config.getString("swagger.apiInfo.contact.url"),
+ config.getString("swagger.apiInfo.contact.email")
+ )),
+ license = Some(License(
+ config.getString("swagger.apiInfo.license"),
+ config.getString("swagger.apiInfo.licenseUrl")
+ )),
+ vendorExtensions = Map.empty[String, AnyRef])
+
+ def swaggerUI = get {
+ pathPrefix("") {
+ pathEndOrSingleSlash {
+ getFromResource("swagger-ui/index.html")
+ }
+ } ~ getFromResourceDirectory("swagger-ui")
+ }
+ }
+}