From eba3d78fd8533703925c7f4d3550ad0c80bbc572 Mon Sep 17 00:00:00 2001 From: Stewart Stewart Date: Mon, 20 Mar 2017 13:50:13 -0400 Subject: turn object xyz.driver.core.rest into package --- src/main/scala/xyz/driver/core/rest.scala | 116 +++++++++++++++--------------- 1 file changed, 59 insertions(+), 57 deletions(-) (limited to 'src') diff --git a/src/main/scala/xyz/driver/core/rest.scala b/src/main/scala/xyz/driver/core/rest.scala index 17837e6..7c9e1d4 100644 --- a/src/main/scala/xyz/driver/core/rest.scala +++ b/src/main/scala/xyz/driver/core/rest.scala @@ -26,79 +26,81 @@ import scala.util.{Failure, Success} import scalaz.Scalaz.{Id => _, _} import scalaz.{ListT, OptionT} -object rest { +package rest { - final case class ServiceRequestContext(trackingId: String = generators.nextUuid().toString, - contextHeaders: Map[String, String] = Map.empty[String, String]) { - - def authToken: Option[AuthToken] = - contextHeaders.get(AuthProvider.AuthenticationTokenHeader).map(AuthToken.apply) - } - - object ContextHeaders { - val AuthenticationTokenHeader = "Authorization" - val AuthenticationHeaderPrefix = "Bearer" - val TrackingIdHeader = "X-Trace" - } + object `package` { + import akka.http.scaladsl.server._ + import Directives._ - import akka.http.scaladsl.server._ - import Directives._ + def serviceContext: Directive1[ServiceRequestContext] = extract(ctx => extractServiceContext(ctx.request)) - def serviceContext: Directive1[ServiceRequestContext] = extract(ctx => extractServiceContext(ctx.request)) + def extractServiceContext(request: HttpRequest): ServiceRequestContext = + ServiceRequestContext(extractTrackingId(request), extractContextHeaders(request)) - def extractServiceContext(request: HttpRequest): ServiceRequestContext = - ServiceRequestContext(extractTrackingId(request), extractContextHeaders(request)) + def extractTrackingId(request: HttpRequest): String = { + request.headers + .find(_.name == ContextHeaders.TrackingIdHeader) + .fold(java.util.UUID.randomUUID.toString)(_.value()) + } - def extractTrackingId(request: HttpRequest): String = { - request.headers - .find(_.name == ContextHeaders.TrackingIdHeader) - .fold(java.util.UUID.randomUUID.toString)(_.value()) - } + def extractContextHeaders(request: HttpRequest): Map[String, String] = { + request.headers.filter { h => + h.name === ContextHeaders.AuthenticationTokenHeader || h.name === ContextHeaders.TrackingIdHeader + } map { header => + if (header.name === ContextHeaders.AuthenticationTokenHeader) { + header.name -> header.value.stripPrefix(ContextHeaders.AuthenticationHeaderPrefix).trim + } else { + header.name -> header.value + } + } toMap + } - def extractContextHeaders(request: HttpRequest): Map[String, String] = { - request.headers.filter { h => - h.name === ContextHeaders.AuthenticationTokenHeader || h.name === ContextHeaders.TrackingIdHeader - } map { header => - if (header.name === ContextHeaders.AuthenticationTokenHeader) { - header.name -> header.value.stripPrefix(ContextHeaders.AuthenticationHeaderPrefix).trim - } else { - header.name -> header.value + private[rest] def escapeScriptTags(byteString: ByteString): ByteString = { + def dirtyIndices(from: Int, descIndices: List[Int]): List[Int] = { + val index = byteString.indexOf('/', from) + if (index === -1) descIndices.reverse + else { + val (init, tail) = byteString.splitAt(index) + if ((init endsWith "<") && (tail startsWith "/sc")) { + dirtyIndices(index + 1, index :: descIndices) + } else { + dirtyIndices(index + 1, descIndices) + } + } } - } toMap - } - private def escapeScriptTags(byteString: ByteString): ByteString = { - def dirtyIndices(from: Int, descIndices: List[Int]): List[Int] = { - val index = byteString.indexOf('/', from) - if (index === -1) descIndices.reverse + val firstSlash = byteString.indexOf('/') + if (firstSlash === -1) byteString else { - val (init, tail) = byteString.splitAt(index) - if ((init endsWith "<") && (tail startsWith "/sc")) { - dirtyIndices(index + 1, index :: descIndices) - } else { - dirtyIndices(index + 1, descIndices) + val indices = dirtyIndices(firstSlash, Nil) :+ byteString.length + val builder = ByteString.newBuilder + builder ++= byteString.take(firstSlash) + indices.sliding(2).foreach { + case Seq(start, end) => + builder += ' ' + builder ++= byteString.slice(start, end) } + builder.result } } - val firstSlash = byteString.indexOf('/') - if (firstSlash === -1) byteString - else { - val indices = dirtyIndices(firstSlash, Nil) :+ byteString.length - val builder = ByteString.newBuilder - builder ++= byteString.take(firstSlash) - indices.sliding(2).foreach { - case Seq(start, end) => - builder += ' ' - builder ++= byteString.slice(start, end) - } - builder.result + val sanitizeRequestEntity: Directive0 = { + mapRequest( + request => request.mapEntity(entity => entity.transformDataBytes(Flow.fromFunction(escapeScriptTags)))) } } - val sanitizeRequestEntity: Directive0 = { - mapRequest( - request => request.mapEntity(entity => entity.transformDataBytes(Flow.fromFunction(escapeScriptTags)))) + final case class ServiceRequestContext(trackingId: String = generators.nextUuid().toString, + contextHeaders: Map[String, String] = Map.empty[String, String]) { + + def authToken: Option[AuthToken] = + contextHeaders.get(AuthProvider.AuthenticationTokenHeader).map(AuthToken.apply) + } + + object ContextHeaders { + val AuthenticationTokenHeader = "Authorization" + val AuthenticationHeaderPrefix = "Bearer" + val TrackingIdHeader = "X-Trace" } object AuthProvider { -- cgit v1.2.3