aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorzachdriver <zach@driver.xyz>2017-11-09 11:29:47 -0800
committerGitHub <noreply@github.com>2017-11-09 11:29:47 -0800
commit523cc3ff8c3ee8f0e968d2a193bf427a8ef772bf (patch)
tree83fa33394ceb3a8ba93f6f28eb5a19d79ae481c8
parenta996353504ffb50d352f8a0fc69681333ef695b4 (diff)
parent07c17c88bf9c71f38f6588cb4a49b8d4d9d70050 (diff)
downloaddriver-core-523cc3ff8c3ee8f0e968d2a193bf427a8ef772bf.tar.gz
driver-core-523cc3ff8c3ee8f0e968d2a193bf427a8ef772bf.tar.bz2
driver-core-523cc3ff8c3ee8f0e968d2a193bf427a8ef772bf.zip
Merge pull request #85 from drivergroup/zsmith/login-auditsv1.6.6
Add originatingIP to ServiceRequestContext
-rw-r--r--src/main/scala/xyz/driver/core/rest/package.scala31
-rw-r--r--src/main/scala/xyz/driver/core/rest/serviceRequestContext.scala14
2 files changed, 36 insertions, 9 deletions
diff --git a/src/main/scala/xyz/driver/core/rest/package.scala b/src/main/scala/xyz/driver/core/rest/package.scala
index 531cd8a..3697d93 100644
--- a/src/main/scala/xyz/driver/core/rest/package.scala
+++ b/src/main/scala/xyz/driver/core/rest/package.scala
@@ -1,8 +1,10 @@
package xyz.driver.core.rest
+import java.net.InetAddress
+
import akka.http.scaladsl.marshalling.{ToEntityMarshaller, ToResponseMarshallable}
import akka.http.scaladsl.model.headers.{HttpOriginRange, Origin, `Access-Control-Allow-Origin`}
-import akka.http.scaladsl.model.{HttpRequest, HttpResponse, ResponseEntity, StatusCodes}
+import akka.http.scaladsl.model._
import akka.http.scaladsl.server.Directives._
import akka.http.scaladsl.server._
import akka.http.scaladsl.unmarshalling.Unmarshal
@@ -12,6 +14,7 @@ import akka.util.ByteString
import xyz.driver.tracing.TracingDirectives
import scala.concurrent.Future
+import scala.util.Try
import scalaz.{Functor, OptionT}
import scalaz.Scalaz.{intInstance, stringInstance}
import scalaz.syntax.equal._
@@ -47,6 +50,7 @@ object `package` {
val AuthenticationHeaderPrefix: String = "Bearer"
val TrackingIdHeader: String = "X-Trace"
val StacktraceHeader: String = "X-Stacktrace"
+ val OriginatingIpHeader: String = "X-Forwarded-For"
val TraceHeaderName: String = TracingDirectives.TraceHeaderName
val SpanHeaderName: String = TracingDirectives.SpanHeaderName
}
@@ -76,6 +80,7 @@ object `package` {
ContextHeaders.SpanHeaderName,
ContextHeaders.StacktraceHeader,
ContextHeaders.AuthenticationTokenHeader,
+ ContextHeaders.OriginatingIpHeader,
"X-Frame-Options",
"X-Content-Type-Options",
"Strict-Transport-Security",
@@ -87,17 +92,30 @@ object `package` {
`Access-Control-Allow-Origin`(
originHeader.fold[HttpOriginRange](HttpOriginRange.*)(h => HttpOriginRange(h.origins: _*)))
- def serviceContext: Directive1[ServiceRequestContext] = extract(ctx => extractServiceContext(ctx.request))
+ def serviceContext: Directive1[ServiceRequestContext] = {
+ extractClientIP flatMap { remoteAddress =>
+ extract(ctx => extractServiceContext(ctx.request, remoteAddress))
+ }
+ }
- def extractServiceContext(request: HttpRequest): ServiceRequestContext =
- new ServiceRequestContext(extractTrackingId(request), extractContextHeaders(request))
+ def extractServiceContext(request: HttpRequest, remoteAddress: RemoteAddress): ServiceRequestContext =
+ new ServiceRequestContext(extractTrackingId(request),
+ extractOriginatingIP(request, remoteAddress),
+ extractContextHeaders(request))
def extractTrackingId(request: HttpRequest): String = {
request.headers
- .find(_.name == ContextHeaders.TrackingIdHeader)
+ .find(_.name === ContextHeaders.TrackingIdHeader)
.fold(java.util.UUID.randomUUID.toString)(_.value())
}
+ def extractOriginatingIP(request: HttpRequest, remoteAddress: RemoteAddress): Option[InetAddress] = {
+ request.headers
+ .find(_.name === ContextHeaders.OriginatingIpHeader)
+ .flatMap(ipName => Try(InetAddress.getByName(ipName.value)).toOption)
+ .orElse(remoteAddress.toOption)
+ }
+
def extractStacktrace(request: HttpRequest): Array[String] =
request.headers.find(_.name == ContextHeaders.StacktraceHeader).fold("")(_.value()).split("->")
@@ -105,7 +123,8 @@ object `package` {
request.headers.filter { h =>
h.name === ContextHeaders.AuthenticationTokenHeader || h.name === ContextHeaders.TrackingIdHeader ||
h.name === ContextHeaders.PermissionsTokenHeader || h.name === ContextHeaders.StacktraceHeader ||
- h.name === ContextHeaders.TraceHeaderName || h.name === ContextHeaders.SpanHeaderName
+ h.name === ContextHeaders.TraceHeaderName || h.name === ContextHeaders.SpanHeaderName ||
+ h.name === ContextHeaders.OriginatingIpHeader
} map { header =>
if (header.name === ContextHeaders.AuthenticationTokenHeader) {
header.name -> header.value.stripPrefix(ContextHeaders.AuthenticationHeaderPrefix).trim
diff --git a/src/main/scala/xyz/driver/core/rest/serviceRequestContext.scala b/src/main/scala/xyz/driver/core/rest/serviceRequestContext.scala
index 4020d57..58ff1f1 100644
--- a/src/main/scala/xyz/driver/core/rest/serviceRequestContext.scala
+++ b/src/main/scala/xyz/driver/core/rest/serviceRequestContext.scala
@@ -1,5 +1,7 @@
package xyz.driver.core.rest
+import java.net.InetAddress
+
import xyz.driver.core.auth.{AuthToken, PermissionsToken, User}
import xyz.driver.core.generators
@@ -7,6 +9,7 @@ import scalaz.Scalaz.{mapEqual, stringInstance}
import scalaz.syntax.equal._
class ServiceRequestContext(val trackingId: String = generators.nextUuid().toString,
+ val originatingIp: Option[InetAddress] = None,
val contextHeaders: Map[String, String] = Map.empty[String, String]) {
def authToken: Option[AuthToken] =
contextHeaders.get(AuthProvider.AuthenticationTokenHeader).map(AuthToken.apply)
@@ -17,28 +20,32 @@ class ServiceRequestContext(val trackingId: String = generators.nextUuid().toStr
def withAuthToken(authToken: AuthToken): ServiceRequestContext =
new ServiceRequestContext(
trackingId,
+ originatingIp,
contextHeaders.updated(AuthProvider.AuthenticationTokenHeader, authToken.value)
)
def withAuthenticatedUser[U <: User](authToken: AuthToken, user: U): AuthorizedServiceRequestContext[U] =
new AuthorizedServiceRequestContext(
trackingId,
+ originatingIp,
contextHeaders.updated(AuthProvider.AuthenticationTokenHeader, authToken.value),
user
)
override def hashCode(): Int =
- Seq[Any](trackingId, contextHeaders).foldLeft(31)((result, obj) => 31 * result + obj.hashCode())
+ Seq[Any](trackingId, originatingIp, contextHeaders).foldLeft(31)((result, obj) => 31 * result + obj.hashCode())
override def equals(obj: Any): Boolean = obj match {
- case ctx: ServiceRequestContext => trackingId === ctx.trackingId && contextHeaders === ctx.contextHeaders
- case _ => false
+ case ctx: ServiceRequestContext =>
+ trackingId === ctx.trackingId && originatingIp == originatingIp && contextHeaders === ctx.contextHeaders
+ case _ => false
}
override def toString: String = s"ServiceRequestContext($trackingId, $contextHeaders)"
}
class AuthorizedServiceRequestContext[U <: User](override val trackingId: String = generators.nextUuid().toString,
+ override val originatingIp: Option[InetAddress] = None,
override val contextHeaders: Map[String, String] =
Map.empty[String, String],
val authenticatedUser: U)
@@ -47,6 +54,7 @@ class AuthorizedServiceRequestContext[U <: User](override val trackingId: String
def withPermissionsToken(permissionsToken: PermissionsToken): AuthorizedServiceRequestContext[U] =
new AuthorizedServiceRequestContext[U](
trackingId,
+ originatingIp,
contextHeaders.updated(AuthProvider.PermissionsTokenHeader, permissionsToken.value),
authenticatedUser)