aboutsummaryrefslogtreecommitdiff
path: root/src/main/scala/xyz/driver/core/rest/package.scala
diff options
context:
space:
mode:
Diffstat (limited to 'src/main/scala/xyz/driver/core/rest/package.scala')
-rw-r--r--src/main/scala/xyz/driver/core/rest/package.scala31
1 files changed, 25 insertions, 6 deletions
diff --git a/src/main/scala/xyz/driver/core/rest/package.scala b/src/main/scala/xyz/driver/core/rest/package.scala
index 531cd8a..3697d93 100644
--- a/src/main/scala/xyz/driver/core/rest/package.scala
+++ b/src/main/scala/xyz/driver/core/rest/package.scala
@@ -1,8 +1,10 @@
package xyz.driver.core.rest
+import java.net.InetAddress
+
import akka.http.scaladsl.marshalling.{ToEntityMarshaller, ToResponseMarshallable}
import akka.http.scaladsl.model.headers.{HttpOriginRange, Origin, `Access-Control-Allow-Origin`}
-import akka.http.scaladsl.model.{HttpRequest, HttpResponse, ResponseEntity, StatusCodes}
+import akka.http.scaladsl.model._
import akka.http.scaladsl.server.Directives._
import akka.http.scaladsl.server._
import akka.http.scaladsl.unmarshalling.Unmarshal
@@ -12,6 +14,7 @@ import akka.util.ByteString
import xyz.driver.tracing.TracingDirectives
import scala.concurrent.Future
+import scala.util.Try
import scalaz.{Functor, OptionT}
import scalaz.Scalaz.{intInstance, stringInstance}
import scalaz.syntax.equal._
@@ -47,6 +50,7 @@ object `package` {
val AuthenticationHeaderPrefix: String = "Bearer"
val TrackingIdHeader: String = "X-Trace"
val StacktraceHeader: String = "X-Stacktrace"
+ val OriginatingIpHeader: String = "X-Forwarded-For"
val TraceHeaderName: String = TracingDirectives.TraceHeaderName
val SpanHeaderName: String = TracingDirectives.SpanHeaderName
}
@@ -76,6 +80,7 @@ object `package` {
ContextHeaders.SpanHeaderName,
ContextHeaders.StacktraceHeader,
ContextHeaders.AuthenticationTokenHeader,
+ ContextHeaders.OriginatingIpHeader,
"X-Frame-Options",
"X-Content-Type-Options",
"Strict-Transport-Security",
@@ -87,17 +92,30 @@ object `package` {
`Access-Control-Allow-Origin`(
originHeader.fold[HttpOriginRange](HttpOriginRange.*)(h => HttpOriginRange(h.origins: _*)))
- def serviceContext: Directive1[ServiceRequestContext] = extract(ctx => extractServiceContext(ctx.request))
+ def serviceContext: Directive1[ServiceRequestContext] = {
+ extractClientIP flatMap { remoteAddress =>
+ extract(ctx => extractServiceContext(ctx.request, remoteAddress))
+ }
+ }
- def extractServiceContext(request: HttpRequest): ServiceRequestContext =
- new ServiceRequestContext(extractTrackingId(request), extractContextHeaders(request))
+ def extractServiceContext(request: HttpRequest, remoteAddress: RemoteAddress): ServiceRequestContext =
+ new ServiceRequestContext(extractTrackingId(request),
+ extractOriginatingIP(request, remoteAddress),
+ extractContextHeaders(request))
def extractTrackingId(request: HttpRequest): String = {
request.headers
- .find(_.name == ContextHeaders.TrackingIdHeader)
+ .find(_.name === ContextHeaders.TrackingIdHeader)
.fold(java.util.UUID.randomUUID.toString)(_.value())
}
+ def extractOriginatingIP(request: HttpRequest, remoteAddress: RemoteAddress): Option[InetAddress] = {
+ request.headers
+ .find(_.name === ContextHeaders.OriginatingIpHeader)
+ .flatMap(ipName => Try(InetAddress.getByName(ipName.value)).toOption)
+ .orElse(remoteAddress.toOption)
+ }
+
def extractStacktrace(request: HttpRequest): Array[String] =
request.headers.find(_.name == ContextHeaders.StacktraceHeader).fold("")(_.value()).split("->")
@@ -105,7 +123,8 @@ object `package` {
request.headers.filter { h =>
h.name === ContextHeaders.AuthenticationTokenHeader || h.name === ContextHeaders.TrackingIdHeader ||
h.name === ContextHeaders.PermissionsTokenHeader || h.name === ContextHeaders.StacktraceHeader ||
- h.name === ContextHeaders.TraceHeaderName || h.name === ContextHeaders.SpanHeaderName
+ h.name === ContextHeaders.TraceHeaderName || h.name === ContextHeaders.SpanHeaderName ||
+ h.name === ContextHeaders.OriginatingIpHeader
} map { header =>
if (header.name === ContextHeaders.AuthenticationTokenHeader) {
header.name -> header.value.stripPrefix(ContextHeaders.AuthenticationHeaderPrefix).trim