From 53829db6555e91de8170b87ca8511d46adfc5442 Mon Sep 17 00:00:00 2001 From: adamw Date: Thu, 31 Aug 2017 15:20:33 +0200 Subject: Breaking redirect loops --- .../scala/com/softwaremill/sttp/SttpHandler.scala | 50 ++++++++++++++++------ .../scala/com/softwaremill/sttp/BasicTests.scala | 11 +++++ 2 files changed, 48 insertions(+), 13 deletions(-) diff --git a/core/src/main/scala/com/softwaremill/sttp/SttpHandler.scala b/core/src/main/scala/com/softwaremill/sttp/SttpHandler.scala index 2f44840..f52eef6 100644 --- a/core/src/main/scala/com/softwaremill/sttp/SttpHandler.scala +++ b/core/src/main/scala/com/softwaremill/sttp/SttpHandler.scala @@ -12,11 +12,16 @@ import scala.language.higherKinds */ trait SttpHandler[R[_], -S] { def send[T](request: Request[T, S]): R[Response[T]] = { + sendWithCounter(request, 0) + } + + private def sendWithCounter[T](request: Request[T, S], + redirects: Int): R[Response[T]] = { val resp = doSend(request) if (request.options.followRedirects) { responseMonad.flatMap(resp) { response: Response[T] => if (response.isRedirect) { - followRedirect(request, response) + followRedirect(request, response, redirects) } else { responseMonad.unit(response) } @@ -27,24 +32,39 @@ trait SttpHandler[R[_], -S] { } private def followRedirect[T](request: Request[T, S], - response: Response[T]): R[Response[T]] = { - def isRelative(uri: String) = !uri.contains("://") + response: Response[T], + redirects: Int): R[Response[T]] = { response.header(LocationHeader).fold(responseMonad.unit(response)) { loc => - val uri = if (isRelative(loc)) { - // using java's URI to resolve a relative URI - uri"${new URI(request.uri.toString).resolve(loc).toString}" + if (redirects >= SttpHandler.MaxRedirects) { + responseMonad.unit(Response(Left("Too many redirects"), 0, Nil, Nil)) } else { - uri"$loc" + followRedirect(request, response, redirects, loc) } + } + } - val redirectResponse = send(request.copy[Id, T, S](uri = uri)) + private def followRedirect[T](request: Request[T, S], + response: Response[T], + redirects: Int, + loc: String): R[Response[T]] = { - responseMonad.map(redirectResponse) { rr => - val responseNoBody = - response.copy(body = response.body.right.map(_ => ())) - rr.copy(history = responseNoBody :: rr.history) - } + def isRelative(uri: String) = !uri.contains("://") + + val uri = if (isRelative(loc)) { + // using java's URI to resolve a relative URI + uri"${new URI(request.uri.toString).resolve(loc).toString}" + } else { + uri"$loc" + } + + val redirectResponse = + sendWithCounter(request.copy[Id, T, S](uri = uri), redirects + 1) + + responseMonad.map(redirectResponse) { rr => + val responseNoBody = + response.copy(body = response.body.right.map(_ => ())) + rr.copy(history = responseNoBody :: rr.history) } } @@ -58,3 +78,7 @@ trait SttpHandler[R[_], -S] { */ def responseMonad: MonadError[R] } + +object SttpHandler { + private[sttp] val MaxRedirects = 32 +} diff --git a/tests/src/test/scala/com/softwaremill/sttp/BasicTests.scala b/tests/src/test/scala/com/softwaremill/sttp/BasicTests.scala index f2639b5..d43c43f 100644 --- a/tests/src/test/scala/com/softwaremill/sttp/BasicTests.scala +++ b/tests/src/test/scala/com/softwaremill/sttp/BasicTests.scala @@ -159,6 +159,9 @@ class BasicTests } ~ path("r4") { complete("819") + } ~ + path("loop") { + redirect("/redirect/loop", StatusCodes.Found) } } @@ -573,11 +576,13 @@ class BasicTests val r2 = sttp.post(uri"$endpoint/redirect/r2") val r3 = sttp.post(uri"$endpoint/redirect/r3") val r4response = "819" + val loop = sttp.post(uri"$endpoint/redirect/loop") name should "not redirect when redirects shouldn't be followed (temporary)" in { val resp = r1.followRedirects(false).send().force() resp.code should be(307) resp.body should be('left) + resp.history should be('empty) } name should "not redirect when redirects shouldn't be followed (permanent)" in { @@ -621,6 +626,12 @@ class BasicTests resp.history(1).code should be(308) resp.history(2).code should be(302) } + + name should "break redirect loops" in { + val resp = loop.send().force() + resp.code should be(0) + resp.history should have size (SttpHandler.MaxRedirects) + } } } -- cgit v1.2.3