diff options
Diffstat (limited to 'src/main/scala/xyz/driver/core/rest/package.scala')
-rw-r--r-- | src/main/scala/xyz/driver/core/rest/package.scala | 92 |
1 files changed, 92 insertions, 0 deletions
diff --git a/src/main/scala/xyz/driver/core/rest/package.scala b/src/main/scala/xyz/driver/core/rest/package.scala new file mode 100644 index 0000000..4c8e13c --- /dev/null +++ b/src/main/scala/xyz/driver/core/rest/package.scala @@ -0,0 +1,92 @@ +package xyz.driver.core + +import akka.http.scaladsl.model.HttpRequest +import akka.http.scaladsl.server.Directives._ +import akka.http.scaladsl.server._ +import akka.stream.scaladsl.Flow +import akka.util.ByteString + +import scalaz.Scalaz.{intInstance, stringInstance} +import scalaz.syntax.equal._ + +package object rest { + object ContextHeaders { + val AuthenticationTokenHeader = "Authorization" + val PermissionsTokenHeader = "Permissions" + val AuthenticationHeaderPrefix = "Bearer" + val TrackingIdHeader = "X-Trace" + val StacktraceHeader = "X-Stacktrace" + val TracingHeader = trace.TracingHeaderKey + } + + object AuthProvider { + val AuthenticationTokenHeader = ContextHeaders.AuthenticationTokenHeader + val PermissionsTokenHeader = ContextHeaders.PermissionsTokenHeader + val SetAuthenticationTokenHeader = "set-authorization" + val SetPermissionsTokenHeader = "set-permissions" + } + + def serviceContext: Directive1[ServiceRequestContext] = extract(ctx => extractServiceContext(ctx.request)) + + def extractServiceContext(request: HttpRequest): ServiceRequestContext = + new ServiceRequestContext(extractTrackingId(request), extractContextHeaders(request)) + + def extractTrackingId(request: HttpRequest): String = { + request.headers + .find(_.name == ContextHeaders.TrackingIdHeader) + .fold(java.util.UUID.randomUUID.toString)(_.value()) + } + + def extractStacktrace(request: HttpRequest): Array[String] = + request.headers.find(_.name == ContextHeaders.StacktraceHeader).fold("")(_.value()).split("->") + + def extractContextHeaders(request: HttpRequest): Map[String, String] = { + request.headers.filter { h => + h.name === ContextHeaders.AuthenticationTokenHeader || h.name === ContextHeaders.TrackingIdHeader || + h.name === ContextHeaders.PermissionsTokenHeader || h.name === ContextHeaders.StacktraceHeader || + h.name === ContextHeaders.TracingHeader + } map { header => + if (header.name === ContextHeaders.AuthenticationTokenHeader) { + header.name -> header.value.stripPrefix(ContextHeaders.AuthenticationHeaderPrefix).trim + } else { + header.name -> header.value + } + } toMap + } + + private[rest] def escapeScriptTags(byteString: ByteString): ByteString = { + @annotation.tailrec + def dirtyIndices(from: Int, descIndices: List[Int]): List[Int] = { + val index = byteString.indexOf('/', from) + if (index === -1) descIndices.reverse + else { + val (init, tail) = byteString.splitAt(index) + if ((init endsWith "<") && (tail startsWith "/sc")) { + dirtyIndices(index + 1, index :: descIndices) + } else { + dirtyIndices(index + 1, descIndices) + } + } + } + + val indices = dirtyIndices(0, Nil) + + indices.headOption.fold(byteString) { head => + val builder = ByteString.newBuilder + builder ++= byteString.take(head) + + (indices :+ byteString.length).sliding(2).foreach { + case Seq(start, end) => + builder += ' ' + builder ++= byteString.slice(start, end) + case Seq(_) => // Should not match; sliding on at least 2 elements + assert(indices.nonEmpty, s"Indices should have been nonEmpty: $indices") + } + builder.result + } + } + + val sanitizeRequestEntity: Directive0 = { + mapRequest(request => request.mapEntity(entity => entity.transformDataBytes(Flow.fromFunction(escapeScriptTags)))) + } +} |