aboutsummaryrefslogtreecommitdiff
path: root/src/main/scala/xyz/driver/core/rest/package.scala
diff options
context:
space:
mode:
Diffstat (limited to 'src/main/scala/xyz/driver/core/rest/package.scala')
-rw-r--r--src/main/scala/xyz/driver/core/rest/package.scala151
1 files changed, 151 insertions, 0 deletions
diff --git a/src/main/scala/xyz/driver/core/rest/package.scala b/src/main/scala/xyz/driver/core/rest/package.scala
new file mode 100644
index 0000000..942ca3a
--- /dev/null
+++ b/src/main/scala/xyz/driver/core/rest/package.scala
@@ -0,0 +1,151 @@
+package xyz.driver.core.rest
+
+import akka.http.scaladsl.marshalling.{ToEntityMarshaller, ToResponseMarshallable}
+import akka.http.scaladsl.model.headers.{HttpOriginRange, Origin, `Access-Control-Allow-Origin`}
+import akka.http.scaladsl.model.{HttpRequest, HttpResponse, ResponseEntity, StatusCodes}
+import akka.http.scaladsl.server.Directives._
+import akka.http.scaladsl.server._
+import akka.http.scaladsl.unmarshalling.Unmarshal
+import akka.stream.scaladsl.Flow
+import akka.util.ByteString
+import xyz.driver.tracing.TracingDirectives
+
+import scala.concurrent.Future
+import scalaz.{Functor, OptionT}
+import scalaz.Scalaz.{intInstance, stringInstance}
+import scalaz.syntax.equal._
+
+trait Service
+
+trait HttpClient {
+ def makeRequest(request: HttpRequest): Future[HttpResponse]
+}
+
+trait ServiceTransport {
+
+ def sendRequestGetResponse(context: ServiceRequestContext)(requestStub: HttpRequest): Future[HttpResponse]
+
+ def sendRequest(context: ServiceRequestContext)(requestStub: HttpRequest): Future[Unmarshal[ResponseEntity]]
+}
+
+final case class Pagination(pageSize: Int, pageNumber: Int)
+
+object `package` {
+ implicit class OptionTRestAdditions[T](optionT: OptionT[Future, T]) {
+ def responseOrNotFound(successCode: StatusCodes.Success = StatusCodes.OK)(
+ implicit F: Functor[Future],
+ em: ToEntityMarshaller[T]): Future[ToResponseMarshallable] = {
+ optionT.fold[ToResponseMarshallable](successCode -> _, StatusCodes.NotFound -> None)
+ }
+ }
+
+ object ContextHeaders {
+ val AuthenticationTokenHeader: String = "Authorization"
+ val PermissionsTokenHeader: String = "Permissions"
+ val AuthenticationHeaderPrefix: String = "Bearer"
+ val TrackingIdHeader: String = "X-Trace"
+ val StacktraceHeader: String = "X-Stacktrace"
+ val TraceHeaderName: String = TracingDirectives.TraceHeaderName
+ val SpanHeaderName: String = TracingDirectives.SpanHeaderName
+ }
+
+ object AuthProvider {
+ val AuthenticationTokenHeader: String = ContextHeaders.AuthenticationTokenHeader
+ val PermissionsTokenHeader: String = ContextHeaders.PermissionsTokenHeader
+ val SetAuthenticationTokenHeader: String = "set-authorization"
+ val SetPermissionsTokenHeader: String = "set-permissions"
+ }
+
+ val AllowedHeaders: Seq[String] =
+ Seq(
+ "Origin",
+ "X-Requested-With",
+ "Content-Type",
+ "Content-Length",
+ "Accept",
+ "X-Trace",
+ "Access-Control-Allow-Methods",
+ "Access-Control-Allow-Origin",
+ "Access-Control-Allow-Headers",
+ "Server",
+ "Date",
+ ContextHeaders.TrackingIdHeader,
+ ContextHeaders.TraceHeaderName,
+ ContextHeaders.SpanHeaderName,
+ ContextHeaders.StacktraceHeader,
+ ContextHeaders.AuthenticationTokenHeader,
+ "X-Frame-Options",
+ "X-Content-Type-Options",
+ "Strict-Transport-Security",
+ AuthProvider.SetAuthenticationTokenHeader,
+ AuthProvider.SetPermissionsTokenHeader
+ )
+
+ def allowOrigin(originHeader: Option[Origin]): `Access-Control-Allow-Origin` =
+ `Access-Control-Allow-Origin`(
+ originHeader.fold[HttpOriginRange](HttpOriginRange.*)(h => HttpOriginRange(h.origins: _*)))
+
+ def serviceContext: Directive1[ServiceRequestContext] = extract(ctx => extractServiceContext(ctx.request))
+
+ def extractServiceContext(request: HttpRequest): ServiceRequestContext =
+ new ServiceRequestContext(extractTrackingId(request), extractContextHeaders(request))
+
+ def extractTrackingId(request: HttpRequest): String = {
+ request.headers
+ .find(_.name == ContextHeaders.TrackingIdHeader)
+ .fold(java.util.UUID.randomUUID.toString)(_.value())
+ }
+
+ def extractStacktrace(request: HttpRequest): Array[String] =
+ request.headers.find(_.name == ContextHeaders.StacktraceHeader).fold("")(_.value()).split("->")
+
+ def extractContextHeaders(request: HttpRequest): Map[String, String] = {
+ request.headers.filter { h =>
+ h.name === ContextHeaders.AuthenticationTokenHeader || h.name === ContextHeaders.TrackingIdHeader ||
+ h.name === ContextHeaders.PermissionsTokenHeader || h.name === ContextHeaders.StacktraceHeader ||
+ h.name === ContextHeaders.TraceHeaderName || h.name === ContextHeaders.SpanHeaderName
+ } map { header =>
+ if (header.name === ContextHeaders.AuthenticationTokenHeader) {
+ header.name -> header.value.stripPrefix(ContextHeaders.AuthenticationHeaderPrefix).trim
+ } else {
+ header.name -> header.value
+ }
+ } toMap
+ }
+
+ private[rest] def escapeScriptTags(byteString: ByteString): ByteString = {
+ @annotation.tailrec
+ def dirtyIndices(from: Int, descIndices: List[Int]): List[Int] = {
+ val index = byteString.indexOf('/', from)
+ if (index === -1) descIndices.reverse
+ else {
+ val (init, tail) = byteString.splitAt(index)
+ if ((init endsWith "<") && (tail startsWith "/sc")) {
+ dirtyIndices(index + 1, index :: descIndices)
+ } else {
+ dirtyIndices(index + 1, descIndices)
+ }
+ }
+ }
+
+ val indices = dirtyIndices(0, Nil)
+
+ indices.headOption.fold(byteString) { head =>
+ val builder = ByteString.newBuilder
+ builder ++= byteString.take(head)
+
+ (indices :+ byteString.length).sliding(2).foreach {
+ case Seq(start, end) =>
+ builder += ' '
+ builder ++= byteString.slice(start, end)
+ case Seq(_) => // Should not match; sliding on at least 2 elements
+ assert(indices.nonEmpty, s"Indices should have been nonEmpty: $indices")
+ }
+ builder.result
+ }
+ }
+
+ val sanitizeRequestEntity: Directive0 = {
+ mapRequest(request => request.mapEntity(entity => entity.transformDataBytes(Flow.fromFunction(escapeScriptTags))))
+ }
+}