aboutsummaryrefslogtreecommitdiff
path: root/src/main/scala/xyz/driver/core/rest/package.scala
diff options
context:
space:
mode:
Diffstat (limited to 'src/main/scala/xyz/driver/core/rest/package.scala')
-rw-r--r--src/main/scala/xyz/driver/core/rest/package.scala92
1 files changed, 92 insertions, 0 deletions
diff --git a/src/main/scala/xyz/driver/core/rest/package.scala b/src/main/scala/xyz/driver/core/rest/package.scala
new file mode 100644
index 0000000..4c8e13c
--- /dev/null
+++ b/src/main/scala/xyz/driver/core/rest/package.scala
@@ -0,0 +1,92 @@
+package xyz.driver.core
+
+import akka.http.scaladsl.model.HttpRequest
+import akka.http.scaladsl.server.Directives._
+import akka.http.scaladsl.server._
+import akka.stream.scaladsl.Flow
+import akka.util.ByteString
+
+import scalaz.Scalaz.{intInstance, stringInstance}
+import scalaz.syntax.equal._
+
+package object rest {
+ object ContextHeaders {
+ val AuthenticationTokenHeader = "Authorization"
+ val PermissionsTokenHeader = "Permissions"
+ val AuthenticationHeaderPrefix = "Bearer"
+ val TrackingIdHeader = "X-Trace"
+ val StacktraceHeader = "X-Stacktrace"
+ val TracingHeader = trace.TracingHeaderKey
+ }
+
+ object AuthProvider {
+ val AuthenticationTokenHeader = ContextHeaders.AuthenticationTokenHeader
+ val PermissionsTokenHeader = ContextHeaders.PermissionsTokenHeader
+ val SetAuthenticationTokenHeader = "set-authorization"
+ val SetPermissionsTokenHeader = "set-permissions"
+ }
+
+ def serviceContext: Directive1[ServiceRequestContext] = extract(ctx => extractServiceContext(ctx.request))
+
+ def extractServiceContext(request: HttpRequest): ServiceRequestContext =
+ new ServiceRequestContext(extractTrackingId(request), extractContextHeaders(request))
+
+ def extractTrackingId(request: HttpRequest): String = {
+ request.headers
+ .find(_.name == ContextHeaders.TrackingIdHeader)
+ .fold(java.util.UUID.randomUUID.toString)(_.value())
+ }
+
+ def extractStacktrace(request: HttpRequest): Array[String] =
+ request.headers.find(_.name == ContextHeaders.StacktraceHeader).fold("")(_.value()).split("->")
+
+ def extractContextHeaders(request: HttpRequest): Map[String, String] = {
+ request.headers.filter { h =>
+ h.name === ContextHeaders.AuthenticationTokenHeader || h.name === ContextHeaders.TrackingIdHeader ||
+ h.name === ContextHeaders.PermissionsTokenHeader || h.name === ContextHeaders.StacktraceHeader ||
+ h.name === ContextHeaders.TracingHeader
+ } map { header =>
+ if (header.name === ContextHeaders.AuthenticationTokenHeader) {
+ header.name -> header.value.stripPrefix(ContextHeaders.AuthenticationHeaderPrefix).trim
+ } else {
+ header.name -> header.value
+ }
+ } toMap
+ }
+
+ private[rest] def escapeScriptTags(byteString: ByteString): ByteString = {
+ @annotation.tailrec
+ def dirtyIndices(from: Int, descIndices: List[Int]): List[Int] = {
+ val index = byteString.indexOf('/', from)
+ if (index === -1) descIndices.reverse
+ else {
+ val (init, tail) = byteString.splitAt(index)
+ if ((init endsWith "<") && (tail startsWith "/sc")) {
+ dirtyIndices(index + 1, index :: descIndices)
+ } else {
+ dirtyIndices(index + 1, descIndices)
+ }
+ }
+ }
+
+ val indices = dirtyIndices(0, Nil)
+
+ indices.headOption.fold(byteString) { head =>
+ val builder = ByteString.newBuilder
+ builder ++= byteString.take(head)
+
+ (indices :+ byteString.length).sliding(2).foreach {
+ case Seq(start, end) =>
+ builder += ' '
+ builder ++= byteString.slice(start, end)
+ case Seq(_) => // Should not match; sliding on at least 2 elements
+ assert(indices.nonEmpty, s"Indices should have been nonEmpty: $indices")
+ }
+ builder.result
+ }
+ }
+
+ val sanitizeRequestEntity: Directive0 = {
+ mapRequest(request => request.mapEntity(entity => entity.transformDataBytes(Flow.fromFunction(escapeScriptTags))))
+ }
+}