From c9e16c212c13929aea51b2362e3bc5f09b9a449d Mon Sep 17 00:00:00 2001 From: Stewart Stewart Date: Sun, 19 Mar 2017 16:50:52 -0400 Subject: scalafmt --- src/main/scala/xyz/driver/core/rest.scala | 36 +++++++++++++++++-------------- 1 file changed, 20 insertions(+), 16 deletions(-) diff --git a/src/main/scala/xyz/driver/core/rest.scala b/src/main/scala/xyz/driver/core/rest.scala index ebdb1b9..9923ef8 100644 --- a/src/main/scala/xyz/driver/core/rest.scala +++ b/src/main/scala/xyz/driver/core/rest.scala @@ -25,9 +25,8 @@ import scalaz.{ListT, OptionT} object rest { - final case class ServiceRequestContext( - trackingId: String = generators.nextUuid().toString, - contextHeaders: Map[String, String] = Map.empty[String, String]) { + final case class ServiceRequestContext(trackingId: String = generators.nextUuid().toString, + contextHeaders: Map[String, String] = Map.empty[String, String]) { def authToken: Option[AuthToken] = contextHeaders.get(AuthProvider.AuthenticationTokenHeader).map(AuthToken.apply) @@ -75,7 +74,8 @@ object rest { } class AlwaysAllowAuthorization extends Authorization { - override def userHasPermission(user: User, permission: Permission)(implicit ctx: ServiceRequestContext): Future[Boolean] = { + override def userHasPermission(user: User, permission: Permission)( + implicit ctx: ServiceRequestContext): Future[Boolean] = { Future.successful(true) } } @@ -100,11 +100,9 @@ object rest { def authorize(permissions: Permission*): Directive1[U] = { serviceContext flatMap { ctx => - onComplete(authenticatedUser(ctx).run flatMap { userOption => userOption.traverse[Future, (U, Boolean)] { user => - permissions - .toList + permissions.toList .traverse[Future, Boolean](authorization.userHasPermission(user, _)(ctx)) .map(results => user -> results.forall(identity)) } @@ -150,11 +148,11 @@ object rest { OptionT[Future, Unit](request.flatMap(_.to[String]).map(_ => Option(()))) protected def optionalResponse[T](request: Future[Unmarshal[ResponseEntity]])( - implicit um: Unmarshaller[ResponseEntity, Option[T]]): OptionT[Future, T] = + implicit um: Unmarshaller[ResponseEntity, Option[T]]): OptionT[Future, T] = OptionT[Future, T](request.flatMap(_.fold(Option.empty[T]))) protected def listResponse[T](request: Future[Unmarshal[ResponseEntity]])( - implicit um: Unmarshaller[ResponseEntity, List[T]]): ListT[Future, T] = + implicit um: Unmarshaller[ResponseEntity, List[T]]): ListT[Future, T] = ListT[Future, T](request.flatMap(_.fold(List.empty[T]))) protected def jsonEntity(json: JsValue): RequestEntity = @@ -194,11 +192,15 @@ object rest { def discover[T <: Service](serviceName: Name[Service]): T } - class HttpRestServiceTransport(actorSystem: ActorSystem, executionContext: ExecutionContext, - log: Logger, stats: Stats, time: TimeProvider) extends ServiceTransport { + class HttpRestServiceTransport(actorSystem: ActorSystem, + executionContext: ExecutionContext, + log: Logger, + stats: Stats, + time: TimeProvider) + extends ServiceTransport { protected implicit val materializer = ActorMaterializer()(actorSystem) - protected implicit val execution = executionContext + protected implicit val execution = executionContext def sendRequestGetResponse(context: ServiceRequestContext)(requestStub: HttpRequest): Future[HttpResponse] = { @@ -206,7 +208,9 @@ object rest { val request = requestStub .withHeaders(RawHeader(ContextHeaders.TrackingIdHeader, context.trackingId)) - .withHeaders(context.contextHeaders.toSeq.map { h => RawHeader(h._1, h._2): HttpHeader }: _*) + .withHeaders(context.contextHeaders.toSeq.map { h => + RawHeader(h._1, h._2): HttpHeader + }: _*) log.audit(s"Sending to ${request.uri} request $request with tracking id ${context.trackingId}") @@ -223,7 +227,7 @@ object rest { log.audit(s"Failed to receive response from ${request.uri} to request $requestStub", t) log.error(s"Failed to receive response from ${request.uri} to request $requestStub", t) stats.recordStats(Seq("request", request.uri.toString, "fail"), TimeRange(requestTime, responseTime), 1) - } (executionContext) + }(executionContext) response } @@ -231,9 +235,9 @@ object rest { def sendRequest(context: ServiceRequestContext)(requestStub: HttpRequest): Future[Unmarshal[ResponseEntity]] = { sendRequestGetResponse(context)(requestStub) map { response => - if(response.status == StatusCodes.NotFound) { + if (response.status == StatusCodes.NotFound) { Unmarshal(HttpEntity.Empty: ResponseEntity) - } else if(response.status.isFailure()) { + } else if (response.status.isFailure()) { throw new Exception(s"Http status is failure ${response.status}") } else { Unmarshal(response.entity) -- cgit v1.2.3 From 0ef027120ec5857ce4c2a86c2aad5a7d925cf573 Mon Sep 17 00:00:00 2001 From: Stewart Stewart Date: Sun, 19 Mar 2017 16:54:55 -0400 Subject: add directive for sanitizing script tags request entities --- src/main/scala/xyz/driver/core/rest.scala | 37 +++++++++++++++++++++++++++++++ 1 file changed, 37 insertions(+) diff --git a/src/main/scala/xyz/driver/core/rest.scala b/src/main/scala/xyz/driver/core/rest.scala index 9923ef8..17837e6 100644 --- a/src/main/scala/xyz/driver/core/rest.scala +++ b/src/main/scala/xyz/driver/core/rest.scala @@ -5,9 +5,12 @@ import akka.http.scaladsl.Http import akka.http.scaladsl.model._ import akka.http.scaladsl.model.headers.{HttpChallenges, RawHeader} import akka.http.scaladsl.server.AuthenticationFailedRejection.CredentialsRejected +import akka.http.scaladsl.server.Directive0 import akka.http.scaladsl.unmarshalling.Unmarshal import akka.http.scaladsl.unmarshalling.Unmarshaller import akka.stream.ActorMaterializer +import akka.stream.scaladsl.Flow +import akka.util.ByteString import com.github.swagger.akka.model._ import com.github.swagger.akka.{HasActorSystem, SwaggerHttpService} import com.typesafe.config.Config @@ -64,6 +67,40 @@ object rest { } toMap } + private def escapeScriptTags(byteString: ByteString): ByteString = { + def dirtyIndices(from: Int, descIndices: List[Int]): List[Int] = { + val index = byteString.indexOf('/', from) + if (index === -1) descIndices.reverse + else { + val (init, tail) = byteString.splitAt(index) + if ((init endsWith "<") && (tail startsWith "/sc")) { + dirtyIndices(index + 1, index :: descIndices) + } else { + dirtyIndices(index + 1, descIndices) + } + } + } + + val firstSlash = byteString.indexOf('/') + if (firstSlash === -1) byteString + else { + val indices = dirtyIndices(firstSlash, Nil) :+ byteString.length + val builder = ByteString.newBuilder + builder ++= byteString.take(firstSlash) + indices.sliding(2).foreach { + case Seq(start, end) => + builder += ' ' + builder ++= byteString.slice(start, end) + } + builder.result + } + } + + val sanitizeRequestEntity: Directive0 = { + mapRequest( + request => request.mapEntity(entity => entity.transformDataBytes(Flow.fromFunction(escapeScriptTags)))) + } + object AuthProvider { val AuthenticationTokenHeader = ContextHeaders.AuthenticationTokenHeader val SetAuthenticationTokenHeader = "set-authorization" -- cgit v1.2.3 From eba3d78fd8533703925c7f4d3550ad0c80bbc572 Mon Sep 17 00:00:00 2001 From: Stewart Stewart Date: Mon, 20 Mar 2017 13:50:13 -0400 Subject: turn object xyz.driver.core.rest into package --- src/main/scala/xyz/driver/core/rest.scala | 116 +++++++++++++++--------------- 1 file changed, 59 insertions(+), 57 deletions(-) diff --git a/src/main/scala/xyz/driver/core/rest.scala b/src/main/scala/xyz/driver/core/rest.scala index 17837e6..7c9e1d4 100644 --- a/src/main/scala/xyz/driver/core/rest.scala +++ b/src/main/scala/xyz/driver/core/rest.scala @@ -26,79 +26,81 @@ import scala.util.{Failure, Success} import scalaz.Scalaz.{Id => _, _} import scalaz.{ListT, OptionT} -object rest { +package rest { - final case class ServiceRequestContext(trackingId: String = generators.nextUuid().toString, - contextHeaders: Map[String, String] = Map.empty[String, String]) { - - def authToken: Option[AuthToken] = - contextHeaders.get(AuthProvider.AuthenticationTokenHeader).map(AuthToken.apply) - } - - object ContextHeaders { - val AuthenticationTokenHeader = "Authorization" - val AuthenticationHeaderPrefix = "Bearer" - val TrackingIdHeader = "X-Trace" - } + object `package` { + import akka.http.scaladsl.server._ + import Directives._ - import akka.http.scaladsl.server._ - import Directives._ + def serviceContext: Directive1[ServiceRequestContext] = extract(ctx => extractServiceContext(ctx.request)) - def serviceContext: Directive1[ServiceRequestContext] = extract(ctx => extractServiceContext(ctx.request)) + def extractServiceContext(request: HttpRequest): ServiceRequestContext = + ServiceRequestContext(extractTrackingId(request), extractContextHeaders(request)) - def extractServiceContext(request: HttpRequest): ServiceRequestContext = - ServiceRequestContext(extractTrackingId(request), extractContextHeaders(request)) + def extractTrackingId(request: HttpRequest): String = { + request.headers + .find(_.name == ContextHeaders.TrackingIdHeader) + .fold(java.util.UUID.randomUUID.toString)(_.value()) + } - def extractTrackingId(request: HttpRequest): String = { - request.headers - .find(_.name == ContextHeaders.TrackingIdHeader) - .fold(java.util.UUID.randomUUID.toString)(_.value()) - } + def extractContextHeaders(request: HttpRequest): Map[String, String] = { + request.headers.filter { h => + h.name === ContextHeaders.AuthenticationTokenHeader || h.name === ContextHeaders.TrackingIdHeader + } map { header => + if (header.name === ContextHeaders.AuthenticationTokenHeader) { + header.name -> header.value.stripPrefix(ContextHeaders.AuthenticationHeaderPrefix).trim + } else { + header.name -> header.value + } + } toMap + } - def extractContextHeaders(request: HttpRequest): Map[String, String] = { - request.headers.filter { h => - h.name === ContextHeaders.AuthenticationTokenHeader || h.name === ContextHeaders.TrackingIdHeader - } map { header => - if (header.name === ContextHeaders.AuthenticationTokenHeader) { - header.name -> header.value.stripPrefix(ContextHeaders.AuthenticationHeaderPrefix).trim - } else { - header.name -> header.value + private[rest] def escapeScriptTags(byteString: ByteString): ByteString = { + def dirtyIndices(from: Int, descIndices: List[Int]): List[Int] = { + val index = byteString.indexOf('/', from) + if (index === -1) descIndices.reverse + else { + val (init, tail) = byteString.splitAt(index) + if ((init endsWith "<") && (tail startsWith "/sc")) { + dirtyIndices(index + 1, index :: descIndices) + } else { + dirtyIndices(index + 1, descIndices) + } + } } - } toMap - } - private def escapeScriptTags(byteString: ByteString): ByteString = { - def dirtyIndices(from: Int, descIndices: List[Int]): List[Int] = { - val index = byteString.indexOf('/', from) - if (index === -1) descIndices.reverse + val firstSlash = byteString.indexOf('/') + if (firstSlash === -1) byteString else { - val (init, tail) = byteString.splitAt(index) - if ((init endsWith "<") && (tail startsWith "/sc")) { - dirtyIndices(index + 1, index :: descIndices) - } else { - dirtyIndices(index + 1, descIndices) + val indices = dirtyIndices(firstSlash, Nil) :+ byteString.length + val builder = ByteString.newBuilder + builder ++= byteString.take(firstSlash) + indices.sliding(2).foreach { + case Seq(start, end) => + builder += ' ' + builder ++= byteString.slice(start, end) } + builder.result } } - val firstSlash = byteString.indexOf('/') - if (firstSlash === -1) byteString - else { - val indices = dirtyIndices(firstSlash, Nil) :+ byteString.length - val builder = ByteString.newBuilder - builder ++= byteString.take(firstSlash) - indices.sliding(2).foreach { - case Seq(start, end) => - builder += ' ' - builder ++= byteString.slice(start, end) - } - builder.result + val sanitizeRequestEntity: Directive0 = { + mapRequest( + request => request.mapEntity(entity => entity.transformDataBytes(Flow.fromFunction(escapeScriptTags)))) } } - val sanitizeRequestEntity: Directive0 = { - mapRequest( - request => request.mapEntity(entity => entity.transformDataBytes(Flow.fromFunction(escapeScriptTags)))) + final case class ServiceRequestContext(trackingId: String = generators.nextUuid().toString, + contextHeaders: Map[String, String] = Map.empty[String, String]) { + + def authToken: Option[AuthToken] = + contextHeaders.get(AuthProvider.AuthenticationTokenHeader).map(AuthToken.apply) + } + + object ContextHeaders { + val AuthenticationTokenHeader = "Authorization" + val AuthenticationHeaderPrefix = "Bearer" + val TrackingIdHeader = "X-Trace" } object AuthProvider { -- cgit v1.2.3 From e63fc3ef064542fd3199b1c6ebee20cc7f1f18e7 Mon Sep 17 00:00:00 2001 From: Stewart Stewart Date: Mon, 20 Mar 2017 13:50:37 -0400 Subject: add test for script tag sanitizer --- src/test/scala/xyz/driver/core/RestTest.scala | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) create mode 100644 src/test/scala/xyz/driver/core/RestTest.scala diff --git a/src/test/scala/xyz/driver/core/RestTest.scala b/src/test/scala/xyz/driver/core/RestTest.scala new file mode 100644 index 0000000..efb9d07 --- /dev/null +++ b/src/test/scala/xyz/driver/core/RestTest.scala @@ -0,0 +1,16 @@ +package xyz.driver.core.rest + +import org.scalatest.{FlatSpec, Matchers} + +import akka.util.ByteString + +class RestTest extends FlatSpec with Matchers { + "`escapeScriptTags` function" should "escap script tags properly" in { + val dirtyString = " Date: Tue, 21 Mar 2017 20:57:45 -0400 Subject: add tailrec annotation --- src/main/scala/xyz/driver/core/rest.scala | 1 + 1 file changed, 1 insertion(+) diff --git a/src/main/scala/xyz/driver/core/rest.scala b/src/main/scala/xyz/driver/core/rest.scala index 7c9e1d4..498ba33 100644 --- a/src/main/scala/xyz/driver/core/rest.scala +++ b/src/main/scala/xyz/driver/core/rest.scala @@ -56,6 +56,7 @@ package rest { } private[rest] def escapeScriptTags(byteString: ByteString): ByteString = { + @annotation.tailrec def dirtyIndices(from: Int, descIndices: List[Int]): List[Int] = { val index = byteString.indexOf('/', from) if (index === -1) descIndices.reverse -- cgit v1.2.3 From 61fe057cd3651773b1ac353d33ea60d6626d4ec3 Mon Sep 17 00:00:00 2001 From: Stewart Stewart Date: Tue, 21 Mar 2017 21:22:18 -0400 Subject: slight refactor of escapeScriptTags --- src/main/scala/xyz/driver/core/rest.scala | 28 +++++++++++++++------------- 1 file changed, 15 insertions(+), 13 deletions(-) diff --git a/src/main/scala/xyz/driver/core/rest.scala b/src/main/scala/xyz/driver/core/rest.scala index 498ba33..fd86b33 100644 --- a/src/main/scala/xyz/driver/core/rest.scala +++ b/src/main/scala/xyz/driver/core/rest.scala @@ -70,19 +70,21 @@ package rest { } } - val firstSlash = byteString.indexOf('/') - if (firstSlash === -1) byteString - else { - val indices = dirtyIndices(firstSlash, Nil) :+ byteString.length - val builder = ByteString.newBuilder - builder ++= byteString.take(firstSlash) - indices.sliding(2).foreach { - case Seq(start, end) => - builder += ' ' - builder ++= byteString.slice(start, end) - } - builder.result - } + val indices = dirtyIndices(0, Nil) + + indices.headOption.fold(byteString){head => + val builder = ByteString.newBuilder + builder ++= byteString.take(head) + + (indices :+ byteString.length).sliding(2).foreach { + case Seq(start, end) => + builder += ' ' + builder ++= byteString.slice(start, end) + case Seq(byteStringLength) => // Should not match; sliding on at least 2 elements + assert(indices.nonEmpty, s"Indices should have been nonEmpty: $indices") + } + builder.result + } } val sanitizeRequestEntity: Directive0 = { -- cgit v1.2.3