aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorStewart Stewart <stewinsalot@gmail.com>2017-03-23 14:48:00 -0700
committerGitHub <noreply@github.com>2017-03-23 14:48:00 -0700
commit3fccf2ef885e944c3bab1b409f343ee0ac177d87 (patch)
tree99be931fbc7217b14bf14721f72d6b1ebf69b497
parentdcceb9aae8073617f43335c83647af5ccf8685ef (diff)
parent61fe057cd3651773b1ac353d33ea60d6626d4ec3 (diff)
downloaddriver-core-0.10.32.tar.gz
driver-core-0.10.32.tar.bz2
driver-core-0.10.32.zip
Merge pull request #27 from drivergroup/xss-escape-directivev0.10.32
Add directive for escaping script tags in request entities
-rw-r--r--src/main/scala/xyz/driver/core/rest.scala132
-rw-r--r--src/test/scala/xyz/driver/core/RestTest.scala16
2 files changed, 105 insertions, 43 deletions
diff --git a/src/main/scala/xyz/driver/core/rest.scala b/src/main/scala/xyz/driver/core/rest.scala
index ebdb1b9..fd86b33 100644
--- a/src/main/scala/xyz/driver/core/rest.scala
+++ b/src/main/scala/xyz/driver/core/rest.scala
@@ -5,9 +5,12 @@ import akka.http.scaladsl.Http
import akka.http.scaladsl.model._
import akka.http.scaladsl.model.headers.{HttpChallenges, RawHeader}
import akka.http.scaladsl.server.AuthenticationFailedRejection.CredentialsRejected
+import akka.http.scaladsl.server.Directive0
import akka.http.scaladsl.unmarshalling.Unmarshal
import akka.http.scaladsl.unmarshalling.Unmarshaller
import akka.stream.ActorMaterializer
+import akka.stream.scaladsl.Flow
+import akka.util.ByteString
import com.github.swagger.akka.model._
import com.github.swagger.akka.{HasActorSystem, SwaggerHttpService}
import com.typesafe.config.Config
@@ -23,11 +26,75 @@ import scala.util.{Failure, Success}
import scalaz.Scalaz.{Id => _, _}
import scalaz.{ListT, OptionT}
-object rest {
+package rest {
- final case class ServiceRequestContext(
- trackingId: String = generators.nextUuid().toString,
- contextHeaders: Map[String, String] = Map.empty[String, String]) {
+ object `package` {
+ import akka.http.scaladsl.server._
+ import Directives._
+
+ def serviceContext: Directive1[ServiceRequestContext] = extract(ctx => extractServiceContext(ctx.request))
+
+ def extractServiceContext(request: HttpRequest): ServiceRequestContext =
+ ServiceRequestContext(extractTrackingId(request), extractContextHeaders(request))
+
+ def extractTrackingId(request: HttpRequest): String = {
+ request.headers
+ .find(_.name == ContextHeaders.TrackingIdHeader)
+ .fold(java.util.UUID.randomUUID.toString)(_.value())
+ }
+
+ def extractContextHeaders(request: HttpRequest): Map[String, String] = {
+ request.headers.filter { h =>
+ h.name === ContextHeaders.AuthenticationTokenHeader || h.name === ContextHeaders.TrackingIdHeader
+ } map { header =>
+ if (header.name === ContextHeaders.AuthenticationTokenHeader) {
+ header.name -> header.value.stripPrefix(ContextHeaders.AuthenticationHeaderPrefix).trim
+ } else {
+ header.name -> header.value
+ }
+ } toMap
+ }
+
+ private[rest] def escapeScriptTags(byteString: ByteString): ByteString = {
+ @annotation.tailrec
+ def dirtyIndices(from: Int, descIndices: List[Int]): List[Int] = {
+ val index = byteString.indexOf('/', from)
+ if (index === -1) descIndices.reverse
+ else {
+ val (init, tail) = byteString.splitAt(index)
+ if ((init endsWith "<") && (tail startsWith "/sc")) {
+ dirtyIndices(index + 1, index :: descIndices)
+ } else {
+ dirtyIndices(index + 1, descIndices)
+ }
+ }
+ }
+
+ val indices = dirtyIndices(0, Nil)
+
+ indices.headOption.fold(byteString){head =>
+ val builder = ByteString.newBuilder
+ builder ++= byteString.take(head)
+
+ (indices :+ byteString.length).sliding(2).foreach {
+ case Seq(start, end) =>
+ builder += ' '
+ builder ++= byteString.slice(start, end)
+ case Seq(byteStringLength) => // Should not match; sliding on at least 2 elements
+ assert(indices.nonEmpty, s"Indices should have been nonEmpty: $indices")
+ }
+ builder.result
+ }
+ }
+
+ val sanitizeRequestEntity: Directive0 = {
+ mapRequest(
+ request => request.mapEntity(entity => entity.transformDataBytes(Flow.fromFunction(escapeScriptTags))))
+ }
+ }
+
+ final case class ServiceRequestContext(trackingId: String = generators.nextUuid().toString,
+ contextHeaders: Map[String, String] = Map.empty[String, String]) {
def authToken: Option[AuthToken] =
contextHeaders.get(AuthProvider.AuthenticationTokenHeader).map(AuthToken.apply)
@@ -39,32 +106,6 @@ object rest {
val TrackingIdHeader = "X-Trace"
}
- import akka.http.scaladsl.server._
- import Directives._
-
- def serviceContext: Directive1[ServiceRequestContext] = extract(ctx => extractServiceContext(ctx.request))
-
- def extractServiceContext(request: HttpRequest): ServiceRequestContext =
- ServiceRequestContext(extractTrackingId(request), extractContextHeaders(request))
-
- def extractTrackingId(request: HttpRequest): String = {
- request.headers
- .find(_.name == ContextHeaders.TrackingIdHeader)
- .fold(java.util.UUID.randomUUID.toString)(_.value())
- }
-
- def extractContextHeaders(request: HttpRequest): Map[String, String] = {
- request.headers.filter { h =>
- h.name === ContextHeaders.AuthenticationTokenHeader || h.name === ContextHeaders.TrackingIdHeader
- } map { header =>
- if (header.name === ContextHeaders.AuthenticationTokenHeader) {
- header.name -> header.value.stripPrefix(ContextHeaders.AuthenticationHeaderPrefix).trim
- } else {
- header.name -> header.value
- }
- } toMap
- }
-
object AuthProvider {
val AuthenticationTokenHeader = ContextHeaders.AuthenticationTokenHeader
val SetAuthenticationTokenHeader = "set-authorization"
@@ -75,7 +116,8 @@ object rest {
}
class AlwaysAllowAuthorization extends Authorization {
- override def userHasPermission(user: User, permission: Permission)(implicit ctx: ServiceRequestContext): Future[Boolean] = {
+ override def userHasPermission(user: User, permission: Permission)(
+ implicit ctx: ServiceRequestContext): Future[Boolean] = {
Future.successful(true)
}
}
@@ -100,11 +142,9 @@ object rest {
def authorize(permissions: Permission*): Directive1[U] = {
serviceContext flatMap { ctx =>
-
onComplete(authenticatedUser(ctx).run flatMap { userOption =>
userOption.traverse[Future, (U, Boolean)] { user =>
- permissions
- .toList
+ permissions.toList
.traverse[Future, Boolean](authorization.userHasPermission(user, _)(ctx))
.map(results => user -> results.forall(identity))
}
@@ -150,11 +190,11 @@ object rest {
OptionT[Future, Unit](request.flatMap(_.to[String]).map(_ => Option(())))
protected def optionalResponse[T](request: Future[Unmarshal[ResponseEntity]])(
- implicit um: Unmarshaller[ResponseEntity, Option[T]]): OptionT[Future, T] =
+ implicit um: Unmarshaller[ResponseEntity, Option[T]]): OptionT[Future, T] =
OptionT[Future, T](request.flatMap(_.fold(Option.empty[T])))
protected def listResponse[T](request: Future[Unmarshal[ResponseEntity]])(
- implicit um: Unmarshaller[ResponseEntity, List[T]]): ListT[Future, T] =
+ implicit um: Unmarshaller[ResponseEntity, List[T]]): ListT[Future, T] =
ListT[Future, T](request.flatMap(_.fold(List.empty[T])))
protected def jsonEntity(json: JsValue): RequestEntity =
@@ -194,11 +234,15 @@ object rest {
def discover[T <: Service](serviceName: Name[Service]): T
}
- class HttpRestServiceTransport(actorSystem: ActorSystem, executionContext: ExecutionContext,
- log: Logger, stats: Stats, time: TimeProvider) extends ServiceTransport {
+ class HttpRestServiceTransport(actorSystem: ActorSystem,
+ executionContext: ExecutionContext,
+ log: Logger,
+ stats: Stats,
+ time: TimeProvider)
+ extends ServiceTransport {
protected implicit val materializer = ActorMaterializer()(actorSystem)
- protected implicit val execution = executionContext
+ protected implicit val execution = executionContext
def sendRequestGetResponse(context: ServiceRequestContext)(requestStub: HttpRequest): Future[HttpResponse] = {
@@ -206,7 +250,9 @@ object rest {
val request = requestStub
.withHeaders(RawHeader(ContextHeaders.TrackingIdHeader, context.trackingId))
- .withHeaders(context.contextHeaders.toSeq.map { h => RawHeader(h._1, h._2): HttpHeader }: _*)
+ .withHeaders(context.contextHeaders.toSeq.map { h =>
+ RawHeader(h._1, h._2): HttpHeader
+ }: _*)
log.audit(s"Sending to ${request.uri} request $request with tracking id ${context.trackingId}")
@@ -223,7 +269,7 @@ object rest {
log.audit(s"Failed to receive response from ${request.uri} to request $requestStub", t)
log.error(s"Failed to receive response from ${request.uri} to request $requestStub", t)
stats.recordStats(Seq("request", request.uri.toString, "fail"), TimeRange(requestTime, responseTime), 1)
- } (executionContext)
+ }(executionContext)
response
}
@@ -231,9 +277,9 @@ object rest {
def sendRequest(context: ServiceRequestContext)(requestStub: HttpRequest): Future[Unmarshal[ResponseEntity]] = {
sendRequestGetResponse(context)(requestStub) map { response =>
- if(response.status == StatusCodes.NotFound) {
+ if (response.status == StatusCodes.NotFound) {
Unmarshal(HttpEntity.Empty: ResponseEntity)
- } else if(response.status.isFailure()) {
+ } else if (response.status.isFailure()) {
throw new Exception(s"Http status is failure ${response.status}")
} else {
Unmarshal(response.entity)
diff --git a/src/test/scala/xyz/driver/core/RestTest.scala b/src/test/scala/xyz/driver/core/RestTest.scala
new file mode 100644
index 0000000..efb9d07
--- /dev/null
+++ b/src/test/scala/xyz/driver/core/RestTest.scala
@@ -0,0 +1,16 @@
+package xyz.driver.core.rest
+
+import org.scalatest.{FlatSpec, Matchers}
+
+import akka.util.ByteString
+
+class RestTest extends FlatSpec with Matchers {
+ "`escapeScriptTags` function" should "escap script tags properly" in {
+ val dirtyString = "</sc----</sc----</sc"
+ val cleanString = "--------------------"
+
+ (escapeScriptTags(ByteString(dirtyString)).utf8String) should be(dirtyString.replace("</sc", "< /sc"))
+
+ (escapeScriptTags(ByteString(cleanString)).utf8String) should be(cleanString)
+ }
+}