1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
|
package xyz.driver.core
import akka.http.scaladsl.model.HttpRequest
import akka.http.scaladsl.server.Directives._
import akka.http.scaladsl.server._
import akka.stream.scaladsl.Flow
import akka.util.ByteString
import xyz.driver.tracing.TracingDirectives
import scalaz.Scalaz.{intInstance, stringInstance}
import scalaz.syntax.equal._
package object rest {
/** Names of the HTTP headers (plus the authorization scheme prefix) used by this REST package. */
object ContextHeaders {
  // Authentication and authorization.
  val AuthenticationTokenHeader: String  = "Authorization"
  val AuthenticationHeaderPrefix: String = "Bearer"
  val PermissionsTokenHeader: String     = "Permissions"

  // Request tracking and diagnostics.
  val TrackingIdHeader: String = "X-Trace"
  val StacktraceHeader: String = "X-Stacktrace"

  // Header names taken over unchanged from the tracing library.
  val TraceHeaderName: String = TracingDirectives.TraceHeaderName
  val SpanHeaderName: String  = TracingDirectives.SpanHeaderName
}
/** Header names related to authentication and permission tokens. */
object AuthProvider {
  // "set-*" header names.
  // NOTE(review): presumably used to hand refreshed tokens back to the client — confirm against callers.
  val SetAuthenticationTokenHeader: String = "set-authorization"
  val SetPermissionsTokenHeader: String    = "set-permissions"

  // Aliases of the corresponding ContextHeaders names.
  val AuthenticationTokenHeader: String = ContextHeaders.AuthenticationTokenHeader
  val PermissionsTokenHeader: String    = ContextHeaders.PermissionsTokenHeader
}
/** Header names this REST package treats as allowed.
  *
  * NOTE(review): presumably rendered into CORS `Access-Control-Allow-Headers` responses — confirm
  * against callers.
  *
  * Every name appears exactly once. The tracking id header used to be listed twice (once as the
  * literal "X-Trace" and once through `ContextHeaders.TrackingIdHeader`); `distinct` additionally
  * guards against the tracing library's header names coinciding with another entry. The relative
  * order of first occurrences is unchanged.
  */
val AllowedHeaders: Seq[String] =
  Seq(
    "Origin",
    "X-Requested-With",
    "Content-Type",
    "Content-Length",
    "Accept",
    ContextHeaders.TrackingIdHeader,
    "Access-Control-Allow-Methods",
    "Access-Control-Allow-Origin",
    "Access-Control-Allow-Headers",
    "Server",
    "Date",
    ContextHeaders.TraceHeaderName,
    ContextHeaders.SpanHeaderName,
    ContextHeaders.StacktraceHeader,
    ContextHeaders.AuthenticationTokenHeader,
    "X-Frame-Options",
    "X-Content-Type-Options",
    "Strict-Transport-Security",
    AuthProvider.SetAuthenticationTokenHeader,
    AuthProvider.SetPermissionsTokenHeader
  ).distinct
/** Directive that provides the [[ServiceRequestContext]] built from the current request. */
def serviceContext: Directive1[ServiceRequestContext] =
  extractRequest.map(extractServiceContext)
/** Builds a [[ServiceRequestContext]] from the request's tracking id and its context headers.
  *
  * @param request incoming HTTP request
  * @return context holding the result of `extractTrackingId` and `extractContextHeaders`
  */
def extractServiceContext(request: HttpRequest): ServiceRequestContext = {
  val trackingId     = extractTrackingId(request)
  val contextHeaders = extractContextHeaders(request)
  new ServiceRequestContext(trackingId, contextHeaders)
}
/** Returns the tracking id carried by the request, or a freshly generated one.
  *
  * The header is looked up case-insensitively: HTTP field names are case-insensitive, so a client
  * sending `x-trace` must be honoured just like one sending `X-Trace`.
  *
  * @param request incoming HTTP request
  * @return value of the first `ContextHeaders.TrackingIdHeader` header, or a random UUID string
  *         when the request has no such header
  */
def extractTrackingId(request: HttpRequest): String =
  request.headers
    .find(_.name.equalsIgnoreCase(ContextHeaders.TrackingIdHeader))
    .map(_.value())
    .getOrElse(java.util.UUID.randomUUID.toString)
/** Splits the request's stacktrace header into its "->"-separated entries.
  *
  * The header is looked up case-insensitively because HTTP field names are case-insensitive.
  *
  * @param request incoming HTTP request
  * @return the entries of the first `ContextHeaders.StacktraceHeader` header; an empty array when
  *         the header is absent or has an empty value (previously this yielded `Array("")`,
  *         because `"".split("->")` returns a single empty element)
  */
def extractStacktrace(request: HttpRequest): Array[String] =
  request.headers
    .find(_.name.equalsIgnoreCase(ContextHeaders.StacktraceHeader))
    .map(_.value())
    .filter(_.nonEmpty)
    .fold(Array.empty[String])(_.split("->"))
/** Collects the context headers of a request into a map.
  *
  * Only the authentication, permissions, tracking id, stacktrace, trace and span headers are kept.
  * Names are matched case-insensitively (HTTP field names are case-insensitive) and the resulting
  * keys always use the canonical spelling from `ContextHeaders`, so lookups by those constants work
  * regardless of how the client spelled the header.
  *
  * @param request incoming HTTP request
  * @return canonical header name -> header value; for the authentication header the
  *         `ContextHeaders.AuthenticationHeaderPrefix` prefix is stripped and the rest trimmed.
  *         When a header occurs several times the last occurrence wins.
  */
def extractContextHeaders(request: HttpRequest): Map[String, String] = {
  // Canonical names of the headers that are carried over into the service context.
  val contextHeaderNames = Seq(
    ContextHeaders.AuthenticationTokenHeader,
    ContextHeaders.TrackingIdHeader,
    ContextHeaders.PermissionsTokenHeader,
    ContextHeaders.StacktraceHeader,
    ContextHeaders.TraceHeaderName,
    ContextHeaders.SpanHeaderName
  )

  request.headers.flatMap { header =>
    contextHeaderNames.find(_.equalsIgnoreCase(header.name)).map { canonicalName =>
      if (canonicalName === ContextHeaders.AuthenticationTokenHeader) {
        // Keep only the token itself, without the authorization scheme.
        canonicalName -> header.value.stripPrefix(ContextHeaders.AuthenticationHeaderPrefix).trim
      } else {
        canonicalName -> header.value
      }
    }
  }.toMap
}
/** Inserts a space between `<` and `/` wherever the input contains the byte sequence `</sc`,
  * so that `</script>` becomes `< /script>`.
  *
  * Matching is exact (case-sensitive) and only looks at the bytes of the given chunk.
  *
  * @param byteString chunk of bytes to sanitize
  * @return the input itself when nothing matched, otherwise a copy with one space inserted before
  *         the `/` of every match
  */
private[rest] def escapeScriptTags(byteString: ByteString): ByteString = {
  // Positions of every '/' that is preceded by '<' and starts "/sc", in ascending order.
  @annotation.tailrec
  def closingTagSlashes(searchFrom: Int, found: List[Int]): List[Int] = {
    val slashAt = byteString.indexOf('/', searchFrom)
    if (slashAt === -1) found.reverse
    else {
      val (before, after) = byteString.splitAt(slashAt)
      val closesScript    = (before endsWith "<") && (after startsWith "/sc")
      closingTagSlashes(slashAt + 1, if (closesScript) slashAt :: found else found)
    }
  }

  closingTagSlashes(0, Nil) match {
    case Nil => byteString
    case positions @ (first :: _) =>
      val out = ByteString.newBuilder
      // Everything before the first match is copied unchanged.
      out ++= byteString.take(first)
      // Each segment runs from one match up to the next match (or the end of the input).
      val segmentEnds = positions.drop(1) :+ byteString.length
      positions.zip(segmentEnds).foreach {
        case (start, end) =>
          out += ' '
          out ++= byteString.slice(start, end)
      }
      out.result
  }
}
/** Directive that passes every data chunk of the request entity through `escapeScriptTags`. */
val sanitizeRequestEntity: Directive0 = {
  val escapingFlow = Flow.fromFunction[ByteString, ByteString](escapeScriptTags)
  mapRequest(_.mapEntity(_.transformDataBytes(escapingFlow)))
}
}
|