From 80db7d7bec56235d0c4a8414ec624d514ec56663 Mon Sep 17 00:00:00 2001 From: Zach Smith Date: Fri, 3 Nov 2017 05:15:55 -0700 Subject: Add originatingIP to ServiceRequestContext --- src/main/scala/xyz/driver/core/rest/package.scala | 31 +++++++++++++++++----- .../driver/core/rest/serviceRequestContext.scala | 14 +++++++--- 2 files changed, 36 insertions(+), 9 deletions(-) diff --git a/src/main/scala/xyz/driver/core/rest/package.scala b/src/main/scala/xyz/driver/core/rest/package.scala index 531cd8a..278a03d 100644 --- a/src/main/scala/xyz/driver/core/rest/package.scala +++ b/src/main/scala/xyz/driver/core/rest/package.scala @@ -1,8 +1,10 @@ package xyz.driver.core.rest +import java.net.InetAddress + import akka.http.scaladsl.marshalling.{ToEntityMarshaller, ToResponseMarshallable} import akka.http.scaladsl.model.headers.{HttpOriginRange, Origin, `Access-Control-Allow-Origin`} -import akka.http.scaladsl.model.{HttpRequest, HttpResponse, ResponseEntity, StatusCodes} +import akka.http.scaladsl.model._ import akka.http.scaladsl.server.Directives._ import akka.http.scaladsl.server._ import akka.http.scaladsl.unmarshalling.Unmarshal @@ -12,6 +14,7 @@ import akka.util.ByteString import xyz.driver.tracing.TracingDirectives import scala.concurrent.Future +import scala.util.Try import scalaz.{Functor, OptionT} import scalaz.Scalaz.{intInstance, stringInstance} import scalaz.syntax.equal._ @@ -47,6 +50,7 @@ object `package` { val AuthenticationHeaderPrefix: String = "Bearer" val TrackingIdHeader: String = "X-Trace" val StacktraceHeader: String = "X-Stacktrace" + val OriginatingIPHeader: String = "X-Forwarded-For" val TraceHeaderName: String = TracingDirectives.TraceHeaderName val SpanHeaderName: String = TracingDirectives.SpanHeaderName } @@ -76,6 +80,7 @@ object `package` { ContextHeaders.SpanHeaderName, ContextHeaders.StacktraceHeader, ContextHeaders.AuthenticationTokenHeader, + ContextHeaders.OriginatingIPHeader, "X-Frame-Options", "X-Content-Type-Options", "Strict-Transport-Security", @@ -87,17 +92,30 @@ object `package` { `Access-Control-Allow-Origin`( originHeader.fold[HttpOriginRange](HttpOriginRange.*)(h => HttpOriginRange(h.origins: _*))) - def serviceContext: Directive1[ServiceRequestContext] = extract(ctx => extractServiceContext(ctx.request)) + def serviceContext: Directive1[ServiceRequestContext] = { + extractClientIP flatMap { remoteAddress => + extract(ctx => extractServiceContext(ctx.request, remoteAddress)) + } + } - def extractServiceContext(request: HttpRequest): ServiceRequestContext = - new ServiceRequestContext(extractTrackingId(request), extractContextHeaders(request)) + def extractServiceContext(request: HttpRequest, remoteAddress: RemoteAddress): ServiceRequestContext = + new ServiceRequestContext(extractTrackingId(request), + extractOriginatingIP(request: HttpRequest, remoteAddress), + extractContextHeaders(request)) def extractTrackingId(request: HttpRequest): String = { request.headers - .find(_.name == ContextHeaders.TrackingIdHeader) + .find(_.name === ContextHeaders.TrackingIdHeader) .fold(java.util.UUID.randomUUID.toString)(_.value()) } + def extractOriginatingIP(request: HttpRequest, remoteAddress: RemoteAddress): Option[InetAddress] = { + request.headers + .find(_.name === ContextHeaders.OriginatingIPHeader) + .flatMap(ipName => Try(InetAddress.getByName(ipName.value)).toOption) + .orElse(remoteAddress.toOption) + } + def extractStacktrace(request: HttpRequest): Array[String] = request.headers.find(_.name == ContextHeaders.StacktraceHeader).fold("")(_.value()).split("->") @@ -105,7 +123,8 @@ object `package` { request.headers.filter { h => h.name === ContextHeaders.AuthenticationTokenHeader || h.name === ContextHeaders.TrackingIdHeader || h.name === ContextHeaders.PermissionsTokenHeader || h.name === ContextHeaders.StacktraceHeader || - h.name === ContextHeaders.TraceHeaderName || h.name === ContextHeaders.SpanHeaderName + h.name === ContextHeaders.TraceHeaderName || h.name === ContextHeaders.SpanHeaderName || + h.name === ContextHeaders.OriginatingIPHeader } map { header => if (header.name === ContextHeaders.AuthenticationTokenHeader) { header.name -> header.value.stripPrefix(ContextHeaders.AuthenticationHeaderPrefix).trim diff --git a/src/main/scala/xyz/driver/core/rest/serviceRequestContext.scala b/src/main/scala/xyz/driver/core/rest/serviceRequestContext.scala index 4020d57..bd8f078 100644 --- a/src/main/scala/xyz/driver/core/rest/serviceRequestContext.scala +++ b/src/main/scala/xyz/driver/core/rest/serviceRequestContext.scala @@ -1,5 +1,7 @@ package xyz.driver.core.rest +import java.net.InetAddress + import xyz.driver.core.auth.{AuthToken, PermissionsToken, User} import xyz.driver.core.generators @@ -7,6 +9,7 @@ import scalaz.Scalaz.{mapEqual, stringInstance} import scalaz.syntax.equal._ class ServiceRequestContext(val trackingId: String = generators.nextUuid().toString, + val originatingIP: Option[InetAddress] = None, val contextHeaders: Map[String, String] = Map.empty[String, String]) { def authToken: Option[AuthToken] = contextHeaders.get(AuthProvider.AuthenticationTokenHeader).map(AuthToken.apply) @@ -17,28 +20,32 @@ class ServiceRequestContext(val trackingId: String = generators.nextUuid().toStr def withAuthToken(authToken: AuthToken): ServiceRequestContext = new ServiceRequestContext( trackingId, + originatingIP, contextHeaders.updated(AuthProvider.AuthenticationTokenHeader, authToken.value) ) def withAuthenticatedUser[U <: User](authToken: AuthToken, user: U): AuthorizedServiceRequestContext[U] = new AuthorizedServiceRequestContext( trackingId, + originatingIP, contextHeaders.updated(AuthProvider.AuthenticationTokenHeader, authToken.value), user ) override def hashCode(): Int = - Seq[Any](trackingId, contextHeaders).foldLeft(31)((result, obj) => 31 * result + obj.hashCode()) + Seq[Any](trackingId, originatingIP, contextHeaders).foldLeft(31)((result, obj) => 31 * result + obj.hashCode()) override def equals(obj: Any): Boolean = obj match { - case ctx: ServiceRequestContext => trackingId === ctx.trackingId && contextHeaders === ctx.contextHeaders - case _ => false + case ctx: ServiceRequestContext => + trackingId === ctx.trackingId && originatingIP == originatingIP && contextHeaders === ctx.contextHeaders + case _ => false } override def toString: String = s"ServiceRequestContext($trackingId, $contextHeaders)" } class AuthorizedServiceRequestContext[U <: User](override val trackingId: String = generators.nextUuid().toString, + override val originatingIP: Option[InetAddress] = None, override val contextHeaders: Map[String, String] = Map.empty[String, String], val authenticatedUser: U) @@ -47,6 +54,7 @@ class AuthorizedServiceRequestContext[U <: User](override val trackingId: String def withPermissionsToken(permissionsToken: PermissionsToken): AuthorizedServiceRequestContext[U] = new AuthorizedServiceRequestContext[U]( trackingId, + originatingIP, contextHeaders.updated(AuthProvider.PermissionsTokenHeader, permissionsToken.value), authenticatedUser) -- cgit v1.2.3 From abedf95d692885fa303c93cd0af5798b5cf82503 Mon Sep 17 00:00:00 2001 From: Zach Smith Date: Mon, 6 Nov 2017 11:23:46 -0800 Subject: Remove unnecessary type annotation --- src/main/scala/xyz/driver/core/rest/package.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/main/scala/xyz/driver/core/rest/package.scala b/src/main/scala/xyz/driver/core/rest/package.scala index 278a03d..012600e 100644 --- a/src/main/scala/xyz/driver/core/rest/package.scala +++ b/src/main/scala/xyz/driver/core/rest/package.scala @@ -100,7 +100,7 @@ object `package` { def extractServiceContext(request: HttpRequest, remoteAddress: RemoteAddress): ServiceRequestContext = new ServiceRequestContext(extractTrackingId(request), - extractOriginatingIP(request: HttpRequest, remoteAddress), + extractOriginatingIP(request, remoteAddress), extractContextHeaders(request)) def extractTrackingId(request: HttpRequest): String = { -- cgit v1.2.3 From 07c17c88bf9c71f38f6588cb4a49b8d4d9d70050 Mon Sep 17 00:00:00 2001 From: Zach Smith Date: Wed, 8 Nov 2017 13:25:05 -0800 Subject: OriginatingIP -> OriginatingIp --- src/main/scala/xyz/driver/core/rest/package.scala | 8 ++++---- .../scala/xyz/driver/core/rest/serviceRequestContext.scala | 14 +++++++------- 2 files changed, 11 insertions(+), 11 deletions(-) diff --git a/src/main/scala/xyz/driver/core/rest/package.scala b/src/main/scala/xyz/driver/core/rest/package.scala index 012600e..3697d93 100644 --- a/src/main/scala/xyz/driver/core/rest/package.scala +++ b/src/main/scala/xyz/driver/core/rest/package.scala @@ -50,7 +50,7 @@ object `package` { val AuthenticationHeaderPrefix: String = "Bearer" val TrackingIdHeader: String = "X-Trace" val StacktraceHeader: String = "X-Stacktrace" - val OriginatingIPHeader: String = "X-Forwarded-For" + val OriginatingIpHeader: String = "X-Forwarded-For" val TraceHeaderName: String = TracingDirectives.TraceHeaderName val SpanHeaderName: String = TracingDirectives.SpanHeaderName } @@ -80,7 +80,7 @@ object `package` { ContextHeaders.SpanHeaderName, ContextHeaders.StacktraceHeader, ContextHeaders.AuthenticationTokenHeader, - ContextHeaders.OriginatingIPHeader, + ContextHeaders.OriginatingIpHeader, "X-Frame-Options", "X-Content-Type-Options", "Strict-Transport-Security", @@ -111,7 +111,7 @@ object `package` { def extractOriginatingIP(request: HttpRequest, remoteAddress: RemoteAddress): Option[InetAddress] = { request.headers - .find(_.name === ContextHeaders.OriginatingIPHeader) + .find(_.name === ContextHeaders.OriginatingIpHeader) .flatMap(ipName => Try(InetAddress.getByName(ipName.value)).toOption) .orElse(remoteAddress.toOption) } @@ -124,7 +124,7 @@ object `package` { h.name === ContextHeaders.AuthenticationTokenHeader || h.name === ContextHeaders.TrackingIdHeader || h.name === ContextHeaders.PermissionsTokenHeader || h.name === ContextHeaders.StacktraceHeader || h.name === ContextHeaders.TraceHeaderName || h.name === ContextHeaders.SpanHeaderName || - h.name === ContextHeaders.OriginatingIPHeader + h.name === ContextHeaders.OriginatingIpHeader } map { header => if (header.name === ContextHeaders.AuthenticationTokenHeader) { header.name -> header.value.stripPrefix(ContextHeaders.AuthenticationHeaderPrefix).trim diff --git a/src/main/scala/xyz/driver/core/rest/serviceRequestContext.scala b/src/main/scala/xyz/driver/core/rest/serviceRequestContext.scala index bd8f078..58ff1f1 100644 --- a/src/main/scala/xyz/driver/core/rest/serviceRequestContext.scala +++ b/src/main/scala/xyz/driver/core/rest/serviceRequestContext.scala @@ -9,7 +9,7 @@ import scalaz.Scalaz.{mapEqual, stringInstance} import scalaz.syntax.equal._ class ServiceRequestContext(val trackingId: String = generators.nextUuid().toString, - val originatingIP: Option[InetAddress] = None, + val originatingIp: Option[InetAddress] = None, val contextHeaders: Map[String, String] = Map.empty[String, String]) { def authToken: Option[AuthToken] = contextHeaders.get(AuthProvider.AuthenticationTokenHeader).map(AuthToken.apply) @@ -20,24 +20,24 @@ class ServiceRequestContext(val trackingId: String = generators.nextUuid().toStr def withAuthToken(authToken: AuthToken): ServiceRequestContext = new ServiceRequestContext( trackingId, - originatingIP, + originatingIp, contextHeaders.updated(AuthProvider.AuthenticationTokenHeader, authToken.value) ) def withAuthenticatedUser[U <: User](authToken: AuthToken, user: U): AuthorizedServiceRequestContext[U] = new AuthorizedServiceRequestContext( trackingId, - originatingIP, + originatingIp, contextHeaders.updated(AuthProvider.AuthenticationTokenHeader, authToken.value), user ) override def hashCode(): Int = - Seq[Any](trackingId, originatingIP, contextHeaders).foldLeft(31)((result, obj) => 31 * result + obj.hashCode()) + Seq[Any](trackingId, originatingIp, contextHeaders).foldLeft(31)((result, obj) => 31 * result + obj.hashCode()) override def equals(obj: Any): Boolean = obj match { case ctx: ServiceRequestContext => - trackingId === ctx.trackingId && originatingIP == originatingIP && contextHeaders === ctx.contextHeaders + trackingId === ctx.trackingId && originatingIp == originatingIp && contextHeaders === ctx.contextHeaders case _ => false } @@ -45,7 +45,7 @@ class ServiceRequestContext(val trackingId: String = generators.nextUuid().toStr } class AuthorizedServiceRequestContext[U <: User](override val trackingId: String = generators.nextUuid().toString, - override val originatingIP: Option[InetAddress] = None, + override val originatingIp: Option[InetAddress] = None, override val contextHeaders: Map[String, String] = Map.empty[String, String], val authenticatedUser: U) @@ -54,7 +54,7 @@ class AuthorizedServiceRequestContext[U <: User](override val trackingId: String def withPermissionsToken(permissionsToken: PermissionsToken): AuthorizedServiceRequestContext[U] = new AuthorizedServiceRequestContext[U]( trackingId, - originatingIP, + originatingIp, contextHeaders.updated(AuthProvider.PermissionsTokenHeader, permissionsToken.value), authenticatedUser) -- cgit v1.2.3