diff options
Diffstat (limited to 'src/main/scala/xyz/driver/core/rest/package.scala')
-rw-r--r-- | src/main/scala/xyz/driver/core/rest/package.scala | 151 |
1 files changed, 151 insertions, 0 deletions
diff --git a/src/main/scala/xyz/driver/core/rest/package.scala b/src/main/scala/xyz/driver/core/rest/package.scala new file mode 100644 index 0000000..942ca3a --- /dev/null +++ b/src/main/scala/xyz/driver/core/rest/package.scala @@ -0,0 +1,151 @@ +package xyz.driver.core.rest + +import akka.http.scaladsl.marshalling.{ToEntityMarshaller, ToResponseMarshallable} +import akka.http.scaladsl.model.headers.{HttpOriginRange, Origin, `Access-Control-Allow-Origin`} +import akka.http.scaladsl.model.{HttpRequest, HttpResponse, ResponseEntity, StatusCodes} +import akka.http.scaladsl.server.Directives._ +import akka.http.scaladsl.server._ +import akka.http.scaladsl.unmarshalling.Unmarshal +import akka.stream.scaladsl.Flow +import akka.util.ByteString +import xyz.driver.tracing.TracingDirectives + +import scala.concurrent.Future +import scalaz.{Functor, OptionT} +import scalaz.Scalaz.{intInstance, stringInstance} +import scalaz.syntax.equal._ + +trait Service + +trait HttpClient { + def makeRequest(request: HttpRequest): Future[HttpResponse] +} + +trait ServiceTransport { + + def sendRequestGetResponse(context: ServiceRequestContext)(requestStub: HttpRequest): Future[HttpResponse] + + def sendRequest(context: ServiceRequestContext)(requestStub: HttpRequest): Future[Unmarshal[ResponseEntity]] +} + +final case class Pagination(pageSize: Int, pageNumber: Int) + +object `package` { + implicit class OptionTRestAdditions[T](optionT: OptionT[Future, T]) { + def responseOrNotFound(successCode: StatusCodes.Success = StatusCodes.OK)( + implicit F: Functor[Future], + em: ToEntityMarshaller[T]): Future[ToResponseMarshallable] = { + optionT.fold[ToResponseMarshallable](successCode -> _, StatusCodes.NotFound -> None) + } + } + + object ContextHeaders { + val AuthenticationTokenHeader: String = "Authorization" + val PermissionsTokenHeader: String = "Permissions" + val AuthenticationHeaderPrefix: String = "Bearer" + val TrackingIdHeader: String = "X-Trace" + val StacktraceHeader: String = "X-Stacktrace" + val TraceHeaderName: String = TracingDirectives.TraceHeaderName + val SpanHeaderName: String = TracingDirectives.SpanHeaderName + } + + object AuthProvider { + val AuthenticationTokenHeader: String = ContextHeaders.AuthenticationTokenHeader + val PermissionsTokenHeader: String = ContextHeaders.PermissionsTokenHeader + val SetAuthenticationTokenHeader: String = "set-authorization" + val SetPermissionsTokenHeader: String = "set-permissions" + } + + val AllowedHeaders: Seq[String] = + Seq( + "Origin", + "X-Requested-With", + "Content-Type", + "Content-Length", + "Accept", + "X-Trace", + "Access-Control-Allow-Methods", + "Access-Control-Allow-Origin", + "Access-Control-Allow-Headers", + "Server", + "Date", + ContextHeaders.TrackingIdHeader, + ContextHeaders.TraceHeaderName, + ContextHeaders.SpanHeaderName, + ContextHeaders.StacktraceHeader, + ContextHeaders.AuthenticationTokenHeader, + "X-Frame-Options", + "X-Content-Type-Options", + "Strict-Transport-Security", + AuthProvider.SetAuthenticationTokenHeader, + AuthProvider.SetPermissionsTokenHeader + ) + + def allowOrigin(originHeader: Option[Origin]): `Access-Control-Allow-Origin` = + `Access-Control-Allow-Origin`( + originHeader.fold[HttpOriginRange](HttpOriginRange.*)(h => HttpOriginRange(h.origins: _*))) + + def serviceContext: Directive1[ServiceRequestContext] = extract(ctx => extractServiceContext(ctx.request)) + + def extractServiceContext(request: HttpRequest): ServiceRequestContext = + new ServiceRequestContext(extractTrackingId(request), extractContextHeaders(request)) + + def extractTrackingId(request: HttpRequest): String = { + request.headers + .find(_.name == ContextHeaders.TrackingIdHeader) + .fold(java.util.UUID.randomUUID.toString)(_.value()) + } + + def extractStacktrace(request: HttpRequest): Array[String] = + request.headers.find(_.name == ContextHeaders.StacktraceHeader).fold("")(_.value()).split("->") + + def extractContextHeaders(request: HttpRequest): Map[String, String] = { + request.headers.filter { h => + h.name === ContextHeaders.AuthenticationTokenHeader || h.name === ContextHeaders.TrackingIdHeader || + h.name === ContextHeaders.PermissionsTokenHeader || h.name === ContextHeaders.StacktraceHeader || + h.name === ContextHeaders.TraceHeaderName || h.name === ContextHeaders.SpanHeaderName + } map { header => + if (header.name === ContextHeaders.AuthenticationTokenHeader) { + header.name -> header.value.stripPrefix(ContextHeaders.AuthenticationHeaderPrefix).trim + } else { + header.name -> header.value + } + } toMap + } + + private[rest] def escapeScriptTags(byteString: ByteString): ByteString = { + @annotation.tailrec + def dirtyIndices(from: Int, descIndices: List[Int]): List[Int] = { + val index = byteString.indexOf('/', from) + if (index === -1) descIndices.reverse + else { + val (init, tail) = byteString.splitAt(index) + if ((init endsWith "<") && (tail startsWith "/sc")) { + dirtyIndices(index + 1, index :: descIndices) + } else { + dirtyIndices(index + 1, descIndices) + } + } + } + + val indices = dirtyIndices(0, Nil) + + indices.headOption.fold(byteString) { head => + val builder = ByteString.newBuilder + builder ++= byteString.take(head) + + (indices :+ byteString.length).sliding(2).foreach { + case Seq(start, end) => + builder += ' ' + builder ++= byteString.slice(start, end) + case Seq(_) => // Should not match; sliding on at least 2 elements + assert(indices.nonEmpty, s"Indices should have been nonEmpty: $indices") + } + builder.result + } + } + + val sanitizeRequestEntity: Directive0 = { + mapRequest(request => request.mapEntity(entity => entity.transformDataBytes(Flow.fromFunction(escapeScriptTags)))) + } +} |