1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
|
package xyz.driver.core.rest
import akka.http.scaladsl.marshalling.{ToEntityMarshaller, ToResponseMarshallable}
import akka.http.scaladsl.model.headers.{HttpOriginRange, Origin, `Access-Control-Allow-Origin`}
import akka.http.scaladsl.model.{HttpRequest, HttpResponse, ResponseEntity, StatusCodes}
import akka.http.scaladsl.server.Directives._
import akka.http.scaladsl.server._
import akka.http.scaladsl.unmarshalling.Unmarshal
import akka.stream.scaladsl.Flow
import akka.util.ByteString
import xyz.driver.tracing.TracingDirectives
import scala.concurrent.Future
import scalaz.{Functor, OptionT}
import scalaz.Scalaz.{intInstance, stringInstance}
import scalaz.syntax.equal._
/** Marker trait for service types; it declares no members of its own. */
trait Service
/** Minimal HTTP client abstraction.
  *
  * How a request is actually executed (connection handling, retries, timeouts) is left to
  * implementations; nothing here constrains it.
  */
trait HttpClient {
/** Issues `request` and yields the eventual `HttpResponse`. */
def makeRequest(request: HttpRequest): Future[HttpResponse]
}
/** Transport used to send requests to other services on behalf of a [[ServiceRequestContext]].
  *
  * NOTE(review): presumably implementations enrich `requestStub` with headers taken from
  * `context` before sending it. That is not visible here, so confirm against the implementations.
  */
trait ServiceTransport {
/** Sends `requestStub` in the given `context` and yields the raw `HttpResponse`. */
def sendRequestGetResponse(context: ServiceRequestContext)(requestStub: HttpRequest): Future[HttpResponse]
/** Sends `requestStub` in the given `context` and yields an `Unmarshal` over the response entity,
  * which the caller can unmarshal to a concrete type.
  */
def sendRequest(context: ServiceRequestContext)(requestStub: HttpRequest): Future[Unmarshal[ResponseEntity]]
}
/** Paging parameters: the number of items per page and the requested page.
  *
  * NOTE(review): whether `pageNumber` is 0- or 1-based is not established here. Confirm with callers.
  */
final case class Pagination(pageSize: Int, pageNumber: Int)
object `package` {

  /** Adds REST response helpers to `OptionT[Future, T]`. */
  implicit class OptionTRestAdditions[T](optionT: OptionT[Future, T]) {

    /** Turns the optional result into a response.
      *
      * Responds with `successCode` and the marshalled value when a value is present, and with
      * `404 Not Found` when it is absent.
      *
      * @param successCode status used when a value is present; defaults to `200 OK`
      * @param F           functor instance for `Future`, required by `OptionT.fold`
      * @param em          marshaller for the value placed in the response entity
      */
    def responseOrNotFound(successCode: StatusCodes.Success = StatusCodes.OK)(
        implicit F: Functor[Future],
        em: ToEntityMarshaller[T]): Future[ToResponseMarshallable] = {
      optionT.fold[ToResponseMarshallable](successCode -> _, StatusCodes.NotFound -> None)
    }
  }

  /** Canonical names of the HTTP headers that carry request context. */
  object ContextHeaders {
    val AuthenticationTokenHeader: String = "Authorization"
    val PermissionsTokenHeader: String = "Permissions"
    // Authentication scheme prefix stripped from the Authorization value in extractContextHeaders.
    val AuthenticationHeaderPrefix: String = "Bearer"
    val TrackingIdHeader: String = "X-Trace"
    val StacktraceHeader: String = "X-Stacktrace"
    val TraceHeaderName: String = TracingDirectives.TraceHeaderName
    val SpanHeaderName: String = TracingDirectives.SpanHeaderName
  }

  /** Header names related to authentication tokens.
    *
    * NOTE(review): the `Set*` names are presumably used by responses that hand new tokens back
    * to clients. Confirm with their users.
    */
  object AuthProvider {
    val AuthenticationTokenHeader: String = ContextHeaders.AuthenticationTokenHeader
    val PermissionsTokenHeader: String = ContextHeaders.PermissionsTokenHeader
    val SetAuthenticationTokenHeader: String = "set-authorization"
    val SetPermissionsTokenHeader: String = "set-permissions"
  }

  /** Header names considered allowed by this service.
    *
    * NOTE(review): presumably rendered into `Access-Control-Allow-Headers` for CORS. Confirm with callers.
    * "X-Trace" appears twice: once as a literal and once via `ContextHeaders.TrackingIdHeader`.
    */
  val AllowedHeaders: Seq[String] =
    Seq(
      "Origin",
      "X-Requested-With",
      "Content-Type",
      "Content-Length",
      "Accept",
      "X-Trace",
      "Access-Control-Allow-Methods",
      "Access-Control-Allow-Origin",
      "Access-Control-Allow-Headers",
      "Server",
      "Date",
      ContextHeaders.TrackingIdHeader,
      ContextHeaders.TraceHeaderName,
      ContextHeaders.SpanHeaderName,
      ContextHeaders.StacktraceHeader,
      ContextHeaders.AuthenticationTokenHeader,
      "X-Frame-Options",
      "X-Content-Type-Options",
      "Strict-Transport-Security",
      AuthProvider.SetAuthenticationTokenHeader,
      AuthProvider.SetPermissionsTokenHeader
    )

  /** Builds an `Access-Control-Allow-Origin` header for a request.
    *
    * Echoes the origins from the request's `Origin` header, or allows any origin (`*`) when the
    * request had no `Origin` header.
    */
  def allowOrigin(originHeader: Option[Origin]): `Access-Control-Allow-Origin` =
    `Access-Control-Allow-Origin`(
      originHeader.fold[HttpOriginRange](HttpOriginRange.*)(h => HttpOriginRange(h.origins: _*)))

  /** Directive that extracts a [[ServiceRequestContext]] from the current request. */
  def serviceContext: Directive1[ServiceRequestContext] = extract(ctx => extractServiceContext(ctx.request))

  /** Builds a [[ServiceRequestContext]] from the request's tracking id and context headers. */
  def extractServiceContext(request: HttpRequest): ServiceRequestContext =
    new ServiceRequestContext(extractTrackingId(request), extractContextHeaders(request))

  /** Value of the first header whose name equals `name`, ignoring case.
    *
    * HTTP field names are case-insensitive, and HTTP/2 transmits them in lowercase, so exact
    * comparison would miss headers such as `x-trace`.
    */
  private def headerValue(request: HttpRequest, name: String): Option[String] =
    request.headers.find(_.name.equalsIgnoreCase(name)).map(_.value)

  /** Returns the tracking id header value, or a newly generated random UUID when it is absent.
    * A generated id is not written back to the request.
    */
  def extractTrackingId(request: HttpRequest): String =
    headerValue(request, ContextHeaders.TrackingIdHeader).getOrElse(java.util.UUID.randomUUID.toString)

  /** Splits the stacktrace header on `"->"`.
    *
    * When the header is absent the empty string is split, so the result is `Array("")` (one empty
    * element), not an empty array.
    */
  def extractStacktrace(request: HttpRequest): Array[String] =
    headerValue(request, ContextHeaders.StacktraceHeader).getOrElse("").split("->")

  // Canonical names of the headers copied into a ServiceRequestContext.
  private val ContextHeaderNames: Seq[String] = Seq(
    ContextHeaders.AuthenticationTokenHeader,
    ContextHeaders.TrackingIdHeader,
    ContextHeaders.PermissionsTokenHeader,
    ContextHeaders.StacktraceHeader,
    ContextHeaders.TraceHeaderName,
    ContextHeaders.SpanHeaderName
  )

  /** Collects the context headers of `request` into a map.
    *
    * Header names are matched case-insensitively, and each entry is keyed by its canonical
    * `ContextHeaders` name. The `Authorization` value has the `Bearer` prefix stripped and is
    * trimmed. If a header occurs more than once, the last occurrence wins.
    */
  def extractContextHeaders(request: HttpRequest): Map[String, String] =
    request.headers.flatMap { header =>
      ContextHeaderNames.find(_.equalsIgnoreCase(header.name)).map { name =>
        val value =
          if (name === ContextHeaders.AuthenticationTokenHeader) {
            header.value.stripPrefix(ContextHeaders.AuthenticationHeaderPrefix).trim
          } else {
            header.value
          }
        name -> value
      }
    }.toMap

  /** Breaks up closing script tags by inserting a space between `<` and `/`.
    *
    * Every `/` that is preceded by `<` and followed by the letters `sc` (in any ASCII letter
    * case) gets a space inserted in front of it, so `</script` becomes `< /script`. Matching
    * ignores case because HTML tag names are case-insensitive: `</SCRIPT` also closes a script
    * element. Input without any match is returned unchanged.
    */
  private[rest] def escapeScriptTags(byteString: ByteString): ByteString = {
    val slash: Byte = '/'.toByte
    val lessThan: Byte = '<'.toByte
    val space: Byte = ' '.toByte

    // True when a byte exists at `index` and is the ASCII letter `lower` in either case.
    def letterAt(index: Int, lower: Char): Boolean =
      index < byteString.length && {
        val b = byteString(index)
        b == lower.toByte || b == lower.toUpper.toByte
      }

    // True when the '/' at `index` is part of a "</sc" sequence.
    def isScriptClose(index: Int): Boolean =
      index > 0 && byteString(index - 1) == lessThan && letterAt(index + 1, 's') && letterAt(index + 2, 'c')

    // Positions of the '/' bytes to escape, in ascending order.
    @annotation.tailrec
    def dirtyIndices(from: Int, descIndices: List[Int]): List[Int] = {
      val index = byteString.indexOf(slash, from)
      if (index === -1) descIndices.reverse
      else dirtyIndices(index + 1, if (isScriptClose(index)) index :: descIndices else descIndices)
    }

    val indices = dirtyIndices(0, Nil)
    indices.headOption.fold(byteString) { head =>
      val builder = ByteString.newBuilder
      builder ++= byteString.take(head)
      // Each segment runs from one dirty '/' up to the next one, or to the end of the input.
      indices.zip(indices.tail :+ byteString.length).foreach {
        case (start, end) =>
          builder += space
          builder ++= byteString.slice(start, end)
      }
      builder.result
    }
  }

  /** Directive that passes the data bytes of every request entity through `escapeScriptTags`.
    *
    * NOTE(review): each streamed chunk is escaped on its own, so a `</sc` sequence split across
    * two chunks of a streamed entity is not escaped.
    */
  val sanitizeRequestEntity: Directive0 = {
    mapRequest(request => request.mapEntity(entity => entity.transformDataBytes(Flow.fromFunction(escapeScriptTags))))
  }
}
|