diff options
Diffstat (limited to 'src')
-rw-r--r-- | src/main/scala/xyz/driver/core/rest.scala | 132 | ||||
-rw-r--r-- | src/test/scala/xyz/driver/core/RestTest.scala | 16 |
2 files changed, 105 insertions, 43 deletions
diff --git a/src/main/scala/xyz/driver/core/rest.scala b/src/main/scala/xyz/driver/core/rest.scala index ebdb1b9..fd86b33 100644 --- a/src/main/scala/xyz/driver/core/rest.scala +++ b/src/main/scala/xyz/driver/core/rest.scala @@ -5,9 +5,12 @@ import akka.http.scaladsl.Http import akka.http.scaladsl.model._ import akka.http.scaladsl.model.headers.{HttpChallenges, RawHeader} import akka.http.scaladsl.server.AuthenticationFailedRejection.CredentialsRejected +import akka.http.scaladsl.server.Directive0 import akka.http.scaladsl.unmarshalling.Unmarshal import akka.http.scaladsl.unmarshalling.Unmarshaller import akka.stream.ActorMaterializer +import akka.stream.scaladsl.Flow +import akka.util.ByteString import com.github.swagger.akka.model._ import com.github.swagger.akka.{HasActorSystem, SwaggerHttpService} import com.typesafe.config.Config @@ -23,11 +26,75 @@ import scala.util.{Failure, Success} import scalaz.Scalaz.{Id => _, _} import scalaz.{ListT, OptionT} -object rest { +package rest { - final case class ServiceRequestContext( - trackingId: String = generators.nextUuid().toString, - contextHeaders: Map[String, String] = Map.empty[String, String]) { + object `package` { + import akka.http.scaladsl.server._ + import Directives._ + + def serviceContext: Directive1[ServiceRequestContext] = extract(ctx => extractServiceContext(ctx.request)) + + def extractServiceContext(request: HttpRequest): ServiceRequestContext = + ServiceRequestContext(extractTrackingId(request), extractContextHeaders(request)) + + def extractTrackingId(request: HttpRequest): String = { + request.headers + .find(_.name == ContextHeaders.TrackingIdHeader) + .fold(java.util.UUID.randomUUID.toString)(_.value()) + } + + def extractContextHeaders(request: HttpRequest): Map[String, String] = { + request.headers.filter { h => + h.name === ContextHeaders.AuthenticationTokenHeader || h.name === ContextHeaders.TrackingIdHeader + } map { header => + if (header.name === ContextHeaders.AuthenticationTokenHeader) { + header.name -> header.value.stripPrefix(ContextHeaders.AuthenticationHeaderPrefix).trim + } else { + header.name -> header.value + } + } toMap + } + + private[rest] def escapeScriptTags(byteString: ByteString): ByteString = { + @annotation.tailrec + def dirtyIndices(from: Int, descIndices: List[Int]): List[Int] = { + val index = byteString.indexOf('/', from) + if (index === -1) descIndices.reverse + else { + val (init, tail) = byteString.splitAt(index) + if ((init endsWith "<") && (tail startsWith "/sc")) { + dirtyIndices(index + 1, index :: descIndices) + } else { + dirtyIndices(index + 1, descIndices) + } + } + } + + val indices = dirtyIndices(0, Nil) + + indices.headOption.fold(byteString){head => + val builder = ByteString.newBuilder + builder ++= byteString.take(head) + + (indices :+ byteString.length).sliding(2).foreach { + case Seq(start, end) => + builder += ' ' + builder ++= byteString.slice(start, end) + case Seq(byteStringLength) => // Should not match; sliding on at least 2 elements + assert(indices.nonEmpty, s"Indices should have been nonEmpty: $indices") + } + builder.result + } + } + + val sanitizeRequestEntity: Directive0 = { + mapRequest( + request => request.mapEntity(entity => entity.transformDataBytes(Flow.fromFunction(escapeScriptTags)))) + } + } + + final case class ServiceRequestContext(trackingId: String = generators.nextUuid().toString, + contextHeaders: Map[String, String] = Map.empty[String, String]) { def authToken: Option[AuthToken] = contextHeaders.get(AuthProvider.AuthenticationTokenHeader).map(AuthToken.apply) @@ -39,32 +106,6 @@ object rest { val TrackingIdHeader = "X-Trace" } - import akka.http.scaladsl.server._ - import Directives._ - - def serviceContext: Directive1[ServiceRequestContext] = extract(ctx => extractServiceContext(ctx.request)) - - def extractServiceContext(request: HttpRequest): ServiceRequestContext = - ServiceRequestContext(extractTrackingId(request), extractContextHeaders(request)) - - def extractTrackingId(request: HttpRequest): String = { - request.headers - .find(_.name == ContextHeaders.TrackingIdHeader) - .fold(java.util.UUID.randomUUID.toString)(_.value()) - } - - def extractContextHeaders(request: HttpRequest): Map[String, String] = { - request.headers.filter { h => - h.name === ContextHeaders.AuthenticationTokenHeader || h.name === ContextHeaders.TrackingIdHeader - } map { header => - if (header.name === ContextHeaders.AuthenticationTokenHeader) { - header.name -> header.value.stripPrefix(ContextHeaders.AuthenticationHeaderPrefix).trim - } else { - header.name -> header.value - } - } toMap - } - object AuthProvider { val AuthenticationTokenHeader = ContextHeaders.AuthenticationTokenHeader val SetAuthenticationTokenHeader = "set-authorization" @@ -75,7 +116,8 @@ object rest { } class AlwaysAllowAuthorization extends Authorization { - override def userHasPermission(user: User, permission: Permission)(implicit ctx: ServiceRequestContext): Future[Boolean] = { + override def userHasPermission(user: User, permission: Permission)( + implicit ctx: ServiceRequestContext): Future[Boolean] = { Future.successful(true) } } @@ -100,11 +142,9 @@ object rest { def authorize(permissions: Permission*): Directive1[U] = { serviceContext flatMap { ctx => - onComplete(authenticatedUser(ctx).run flatMap { userOption => userOption.traverse[Future, (U, Boolean)] { user => - permissions - .toList + permissions.toList .traverse[Future, Boolean](authorization.userHasPermission(user, _)(ctx)) .map(results => user -> results.forall(identity)) } @@ -150,11 +190,11 @@ object rest { OptionT[Future, Unit](request.flatMap(_.to[String]).map(_ => Option(()))) protected def optionalResponse[T](request: Future[Unmarshal[ResponseEntity]])( - implicit um: Unmarshaller[ResponseEntity, Option[T]]): OptionT[Future, T] = + implicit um: Unmarshaller[ResponseEntity, Option[T]]): OptionT[Future, T] = OptionT[Future, T](request.flatMap(_.fold(Option.empty[T]))) protected def listResponse[T](request: Future[Unmarshal[ResponseEntity]])( - implicit um: Unmarshaller[ResponseEntity, List[T]]): ListT[Future, T] = + implicit um: Unmarshaller[ResponseEntity, List[T]]): ListT[Future, T] = ListT[Future, T](request.flatMap(_.fold(List.empty[T]))) protected def jsonEntity(json: JsValue): RequestEntity = @@ -194,11 +234,15 @@ object rest { def discover[T <: Service](serviceName: Name[Service]): T } - class HttpRestServiceTransport(actorSystem: ActorSystem, executionContext: ExecutionContext, - log: Logger, stats: Stats, time: TimeProvider) extends ServiceTransport { + class HttpRestServiceTransport(actorSystem: ActorSystem, + executionContext: ExecutionContext, + log: Logger, + stats: Stats, + time: TimeProvider) + extends ServiceTransport { protected implicit val materializer = ActorMaterializer()(actorSystem) - protected implicit val execution = executionContext + protected implicit val execution = executionContext def sendRequestGetResponse(context: ServiceRequestContext)(requestStub: HttpRequest): Future[HttpResponse] = { @@ -206,7 +250,9 @@ object rest { val request = requestStub .withHeaders(RawHeader(ContextHeaders.TrackingIdHeader, context.trackingId)) - .withHeaders(context.contextHeaders.toSeq.map { h => RawHeader(h._1, h._2): HttpHeader }: _*) + .withHeaders(context.contextHeaders.toSeq.map { h => + RawHeader(h._1, h._2): HttpHeader + }: _*) log.audit(s"Sending to ${request.uri} request $request with tracking id ${context.trackingId}") @@ -223,7 +269,7 @@ object rest { log.audit(s"Failed to receive response from ${request.uri} to request $requestStub", t) log.error(s"Failed to receive response from ${request.uri} to request $requestStub", t) stats.recordStats(Seq("request", request.uri.toString, "fail"), TimeRange(requestTime, responseTime), 1) - } (executionContext) + }(executionContext) response } @@ -231,9 +277,9 @@ object rest { def sendRequest(context: ServiceRequestContext)(requestStub: HttpRequest): Future[Unmarshal[ResponseEntity]] = { sendRequestGetResponse(context)(requestStub) map { response => - if(response.status == StatusCodes.NotFound) { + if (response.status == StatusCodes.NotFound) { Unmarshal(HttpEntity.Empty: ResponseEntity) - } else if(response.status.isFailure()) { + } else if (response.status.isFailure()) { throw new Exception(s"Http status is failure ${response.status}") } else { Unmarshal(response.entity) diff --git a/src/test/scala/xyz/driver/core/RestTest.scala b/src/test/scala/xyz/driver/core/RestTest.scala new file mode 100644 index 0000000..efb9d07 --- /dev/null +++ b/src/test/scala/xyz/driver/core/RestTest.scala @@ -0,0 +1,16 @@ +package xyz.driver.core.rest + +import org.scalatest.{FlatSpec, Matchers} + +import akka.util.ByteString + +class RestTest extends FlatSpec with Matchers { + "`escapeScriptTags` function" should "escap script tags properly" in { + val dirtyString = "</sc----</sc----</sc" + val cleanString = "--------------------" + + (escapeScriptTags(ByteString(dirtyString)).utf8String) should be(dirtyString.replace("</sc", "< /sc")) + + (escapeScriptTags(ByteString(cleanString)).utf8String) should be(cleanString) + } +} |