package xyz.driver.core.rest

import akka.NotUsed
import akka.http.scaladsl.model.{HttpRequest, HttpResponse, ResponseEntity}
import akka.http.scaladsl.server.Directives._
import akka.http.scaladsl.server._
import akka.http.scaladsl.unmarshalling.Unmarshal
import akka.stream.scaladsl.{Flow, Source}
import akka.util.ByteString
import xyz.driver.tracing.TracingDirectives

import scala.concurrent.Future
import scalaz.Scalaz.{intInstance, stringInstance}
import scalaz.syntax.equal._

/** Empty marker trait; presumably extended by service definitions elsewhere. */
trait Service

/** Abstraction over something that performs a single HTTP request. */
trait HttpClient {
  def makeRequest(request: HttpRequest): Future[HttpResponse]
}

/** Transport for calls to other services, parameterized by the caller's `ServiceRequestContext`. */
trait ServiceTransport {

  /** Sends `requestStub` and completes with the raw `HttpResponse`. */
  def sendRequestGetResponse(context: ServiceRequestContext)(requestStub: HttpRequest): Future[HttpResponse]

  /** Sends `requestStub` and completes with the response entity wrapped in an `Unmarshal`. */
  def sendRequest(context: ServiceRequestContext)(requestStub: HttpRequest): Future[Unmarshal[ResponseEntity]]
}

/** Paging parameters: number of items per page and which page to return. */
final case class Pagination(pageSize: Int, pageNumber: Int)

object `package` {

  /** Names of the request headers collected into a `ServiceRequestContext` by `extractContextHeaders`. */
  object ContextHeaders {
    val AuthenticationTokenHeader: String = "Authorization"
    val PermissionsTokenHeader: String    = "Permissions"

    /** Auth-scheme prefix that `extractContextHeaders` strips from the `Authorization` value. */
    val AuthenticationHeaderPrefix: String = "Bearer"
    val TrackingIdHeader: String           = "X-Trace"
    val StacktraceHeader: String           = "X-Stacktrace"
    val TraceHeaderName: String            = TracingDirectives.TraceHeaderName
    val SpanHeaderName: String             = TracingDirectives.SpanHeaderName
  }

  /** Header names related to authentication and permission tokens. */
  object AuthProvider {
    val AuthenticationTokenHeader: String = ContextHeaders.AuthenticationTokenHeader
    val PermissionsTokenHeader: String    = ContextHeaders.PermissionsTokenHeader
    // NOTE(review): presumably set on responses to hand refreshed tokens back to the client — confirm.
    val SetAuthenticationTokenHeader: String = "set-authorization"
    val SetPermissionsTokenHeader: String    = "set-permissions"
  }

  /** Header names this service allows.
    *
    * NOTE(review): presumably advertised via CORS `Access-Control-Allow-Headers` / `Expose-Headers`;
    * confirm at the use site. `"X-Trace"` appears twice (as a literal and as `TrackingIdHeader`);
    * it is kept so the sequence stays exactly as before.
    */
  val AllowedHeaders: Seq[String] = Seq(
    "Origin",
    "X-Requested-With",
    "Content-Type",
    "Content-Length",
    "Accept",
    "X-Trace",
    "Access-Control-Allow-Methods",
    "Access-Control-Allow-Origin",
    "Access-Control-Allow-Headers",
    "Server",
    "Date",
    ContextHeaders.TrackingIdHeader,
    ContextHeaders.TraceHeaderName,
    ContextHeaders.SpanHeaderName,
    ContextHeaders.StacktraceHeader,
    ContextHeaders.AuthenticationTokenHeader,
    "X-Frame-Options",
    "X-Content-Type-Options",
    "Strict-Transport-Security",
    AuthProvider.SetAuthenticationTokenHeader,
    AuthProvider.SetPermissionsTokenHeader
  )

  /** Canonical names of the headers kept by `extractContextHeaders`, in lookup order. */
  private val ContextHeaderNames: Seq[String] = Seq(
    ContextHeaders.AuthenticationTokenHeader,
    ContextHeaders.TrackingIdHeader,
    ContextHeaders.PermissionsTokenHeader,
    ContextHeaders.StacktraceHeader,
    ContextHeaders.TraceHeaderName,
    ContextHeaders.SpanHeaderName
  )

  /** Directive that provides the `ServiceRequestContext` of the current request. */
  def serviceContext: Directive1[ServiceRequestContext] = extract(ctx => extractServiceContext(ctx.request))

  /** Builds a `ServiceRequestContext` from the tracking id and context headers of `request`. */
  def extractServiceContext(request: HttpRequest): ServiceRequestContext =
    new ServiceRequestContext(extractTrackingId(request), extractContextHeaders(request))

  /** Value of the first header of `request` called `name`.
    *
    * Header names are compared case-insensitively (RFC 7230 §3.2): Akka keeps the wire casing of
    * unmodeled headers, so e.g. an HTTP/2 client's `x-trace` must still match `X-Trace`.
    */
  private def findHeaderValue(request: HttpRequest, name: String): Option[String] =
    request.headers.find(_.name.equalsIgnoreCase(name)).map(_.value())

  /** The `X-Trace` header value of `request`, or a freshly generated random UUID when it is absent. */
  def extractTrackingId(request: HttpRequest): String =
    findHeaderValue(request, ContextHeaders.TrackingIdHeader).getOrElse(java.util.UUID.randomUUID.toString)

  /** The `X-Stacktrace` header value split on `"->"`.
    *
    * NOTE(review): when the header is absent this returns `Array("")` rather than an empty array;
    * kept as-is because callers may rely on it — confirm.
    */
  def extractStacktrace(request: HttpRequest): Array[String] =
    findHeaderValue(request, ContextHeaders.StacktraceHeader).getOrElse("").split("->")

  /** Collects the context headers of `request` into a map keyed by their canonical names.
    *
    * Header names match case-insensitively; the map key is always the constant from
    * `ContextHeaders`. The `Authorization` value has its `Bearer` prefix stripped and is trimmed.
    * If a header occurs more than once, the last occurrence wins.
    */
  def extractContextHeaders(request: HttpRequest): Map[String, String] =
    request.headers.flatMap { header =>
      ContextHeaderNames.find(header.name.equalsIgnoreCase(_)).toList.map { name =>
        if (name === ContextHeaders.AuthenticationTokenHeader) {
          name -> header.value().stripPrefix(ContextHeaders.AuthenticationHeaderPrefix).trim
        } else {
          name -> header.value()
        }
      }
    }.toMap

  /** Lowercase form of the byte sequence that starts a `</script` closing tag. */
  private val ScriptClosePrefix: String = "</sc"

  /** ASCII-lowercases a single byte; bytes outside `A`-`Z` are returned unchanged (as `Int`). */
  private def toLowerAscii(byte: Byte): Int =
    if (byte >= 'A' && byte <= 'Z') byte + ('a' - 'A') else byte.toInt

  /** Whether the `count` bytes of `bytes` starting at `start` equal the first `count` characters of
    * `</sc`, ignoring ASCII case. Callers guarantee that the range is in bounds.
    */
  private def matchesScriptClose(bytes: ByteString, start: Int, count: Int): Boolean =
    (0 until count).forall(i => toLowerAscii(bytes(start + i)) === ScriptClosePrefix.charAt(i).toInt)

  /** Length (0 to 3) of the trailing bytes of `bytes` that form a proper prefix of `</sc`.
    *
    * Those bytes might combine with the next chunk into a full `</sc`, so they must be held back.
    */
  private def pendingScriptCloseLength(bytes: ByteString): Int =
    ((ScriptClosePrefix.length - 1) to 1 by -1)
      .find(n => n <= bytes.length && matchesScriptClose(bytes, bytes.length - n, n))
      .getOrElse(0)

  /** Defuses closing script tags in `byteString`.
    *
    * A space is inserted between the `<` and `/` of every `</sc` sequence, so `</script>` becomes
    * `< /script>` and is no longer an HTML end tag. Matching ignores ASCII case because HTML end
    * tags do: `</SCRIPT>` closes a script element too. The input is returned unchanged when there
    * is nothing to escape.
    */
  private[rest] def escapeScriptTags(byteString: ByteString): ByteString = {
    // True when the '/' at `slash` is the second byte of a complete "</sc" sequence.
    def startsScriptClose(slash: Int): Boolean =
      slash >= 1 && slash + 2 < byteString.length &&
        matchesScriptClose(byteString, slash - 1, ScriptClosePrefix.length)

    // Positions of the '/' of every "</sc", ascending. `found` accumulates in reverse order.
    @annotation.tailrec
    def dirtyIndices(from: Int, found: List[Int]): List[Int] = {
      val index = byteString.indexOf('/', from)
      if (index === -1) found.reverse
      else dirtyIndices(index + 1, if (startsScriptClose(index)) index :: found else found)
    }

    dirtyIndices(0, Nil) match {
      case Nil => byteString
      case indices @ (first :: _) =>
        val builder = ByteString.newBuilder
        builder ++= byteString.take(first)
        // Each segment starts at a dirty '/' and runs to the next one (or to the end).
        indices.zip(indices.drop(1) :+ byteString.length).foreach {
          case (start, end) =>
            builder += ' '.toByte
            builder ++= byteString.slice(start, end)
        }
        builder.result()
    }
  }

  /** Streaming version of `escapeScriptTags` that also catches `</sc` split across chunks.
    *
    * A trailing partial `</sc` prefix of each chunk is carried over and prepended to the next
    * chunk. A `None` sentinel is appended to the stream so the final carry is flushed on
    * completion. Empty output chunks are dropped.
    */
  private[rest] val escapeScriptTagsFlow: Flow[ByteString, ByteString, NotUsed] =
    Flow[ByteString]
      .map(chunk => Option(chunk))
      .concat(Source.single(Option.empty[ByteString]))
      // State is (bytes to emit now, bytes held back for the next chunk).
      .scan((ByteString.empty, ByteString.empty)) {
        case ((_, carry), Some(chunk)) =>
          val buffer           = carry ++ chunk
          val (ready, pending) = buffer.splitAt(buffer.length - pendingScriptCloseLength(buffer))
          (escapeScriptTags(ready), pending)
        case ((_, carry), None) =>
          // Fewer than four bytes remain, so they cannot hold a full "</sc".
          (carry, ByteString.empty)
      }
      .map { case (output, _) => output }
      .filter(_.nonEmpty)

  /** Directive that passes the request entity's bytes through `escapeScriptTagsFlow`. */
  val sanitizeRequestEntity: Directive0 =
    mapRequest(request => request.mapEntity(entity => entity.transformDataBytes(escapeScriptTagsFlow)))
}